From ac5b6d57d2d6ab51c9f0bea4499c1d29c9ea95a7 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 18:21:34 -0800 Subject: [PATCH 01/52] add new examples Signed-off-by: yiliu30 --- inc_examples/generate.py | 72 ++++++++++++++++++++++++++++ inc_examples/quantize.py | 98 +++++++++++++++++++++++++++++++++++++++ inc_examples/run_gen.sh | 18 +++++++ inc_examples/run_quant.sh | 25 ++++++++++ 4 files changed, 213 insertions(+) create mode 100644 inc_examples/generate.py create mode 100644 inc_examples/quantize.py create mode 100644 inc_examples/run_gen.sh create mode 100644 inc_examples/run_quant.sh diff --git a/inc_examples/generate.py b/inc_examples/generate.py new file mode 100644 index 000000000..de9781c60 --- /dev/null +++ b/inc_examples/generate.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +try: + from auto_round_extension.vllm_ext import apply as apply_auto_round_extension + apply_auto_round_extension() +except ImportError: + print("auto_round_extension.vllm_ext not found, proceeding without auto-round extension.") + +from vllm import LLM, EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + + return parser + + +def main(args: dict): + # Pop arguments not used by LLM + max_tokens = args.pop("max_tokens") + temperature = args.pop("temperature") + top_p = args.pop("top_p") + top_k = args.pop("top_k") + + # Create an LLM + llm = LLM(**args) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if temperature is not None: + sampling_params.temperature = temperature + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. 
+ print("-" * 50) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + +if __name__ == "__main__": + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py new file mode 100644 index 000000000..f8c15ccf8 --- /dev/null +++ b/inc_examples/quantize.py @@ -0,0 +1,98 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import transformers +import logging +from auto_round import AutoRound +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def quant_model(args): + fp_layers = "shared_experts,lm_head,mlp.gate" + if args.skip_attn: + fp_layers += ",self_attn" + logger.info(f"Using fp_layers: {fp_layers}") + autoround = AutoRound( + model=args.model, + scheme=args.scheme, + enable_torch_compile=args.enable_torch_compile, + iters=args.iters, + fp_layers=fp_layers, + ) + logger.info(f"Save quantized model to {args.output_dir}") + format_type = "auto_round" if args.use_autoround_format else "llm_compressor" + autoround.quantize_and_save( + format=format_type, + output_dir=args.output_dir, + ) + + +if __name__ == "__main__": + import argparse + + # import ar_schemes # Assuming `ar_schemes` is a module in your project + import auto_round.schemes as ar_schemes + + # Define available schemes + AVAILABLE_SCHEMES = { + "MXFP8": "MXFP8", + "FP8_STATIC": ar_schemes.FP8_STATIC, + "MXFP8_AR": ar_schemes.MXFP8, + "MXFP4_AR": ar_schemes.MXFP4, + "MXFP4": "MXFP4", + "W4A16": "W4A16", + "NVFP4": ar_schemes.NVFP4, + } + + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Select a quantization scheme.") + parser.add_argument( + "--model", + type=str, + help="Path to the pre-trained model or model identifier from Hugging Face Hub.", + ) + parser.add_argument( + "--scheme", + type=str, + choices=AVAILABLE_SCHEMES.keys(), + default="MXFP4", + help="Quantization scheme to use. 
Available options: " + ", ".join(AVAILABLE_SCHEMES.keys()), + ) + + parser.add_argument( + "--enable_torch_compile", + action="store_true", + help="Enable torch compile for the model.", + ) + parser.add_argument( + "--use_autoround_format", + action="store_true", + help="Use AutoRound format for saving the quantized model.", + ) + + parser.add_argument( + "--skip_attn", + action="store_true", + help="Skip quantize attention layers.", + ) + parser.add_argument( + "--iters", + type=int, + default=0, + help="Number of iterations for quantization.", + ) + # output_dir can also be added as an argument if needed + parser.add_argument( + "--output_dir", + type=str, + default="quantized_model", + help="Directory to save the quantized model.", + ) + + args = parser.parse_args() + + # Set the scheme based on user input + scheme = AVAILABLE_SCHEMES[args.scheme] + + # Print the selected scheme for confirmation + logger.info(f"Selected quantization scheme: {args.scheme}") + quant_model(args) diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh new file mode 100644 index 000000000..15cea5ebc --- /dev/null +++ b/inc_examples/run_gen.sh @@ -0,0 +1,18 @@ +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_ENABLE_V1_MULTIPROCESSING=0 + + + +model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" +model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" +model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" + + +# VLLM_ATTENTION_BACKEND=TRITON_ATTN \ +VLLM_USE_DEEP_GEMM=0 \ +VLLM_LOGGING_LEVEL=DEBUG \ +VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +python generate.py \ + --model ${model_path} \ + --max-tokens 64 \ + --enforce-eager \ No newline at end of file diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh new file mode 100644 index 000000000..e4824b7d6 --- /dev/null +++ b/inc_examples/run_quant.sh @@ -0,0 +1,25 @@ + +export AR_LOG_LEVEL=TRACE +model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" +# model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" +base_name=$(basename ${model}) +scheme="MXFP4" +# scheme="MXFP8" +qmodel_dir="quantized_models/" +mkdir -p ${qmodel_dir} +output_dir="${qmodel_dir}/${base_name}-${scheme}" +python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format +# model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" + +# scheme="MXFP4" + +# output_path="./" +# base_name=$(basename ${model_name}) +# CUDA_VISIBLE_DEVICES=$device \ +# python3 quantize.py \ +# --model ${model} \ +# --scheme ${scheme} \ +# --format llm_compressor \ +# --iters 0 \ +# --enable_torch_compile \ +# --output_dir ${output_path}/${base_name}-${scheme} \ No newline at end of file From 2a2e83485bfb0be19a15522caf4226fb6d767672 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 19:32:40 -0800 Subject: [PATCH 02/52] fix mxfp4 moe for qwen Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/__init__.py | 2 +- auto_round_extension/vllm_ext/envs_ext.py | 1 + auto_round_extension/vllm_ext/moe_impl_mxfp4.py | 10 ++++++---- inc_examples/run_gen.sh | 8 ++++++-- inc_examples/run_quant.sh | 1 + 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 3e145334f..bccfb4c4e 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -23,4 +23,4 @@ def apply(): from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig - from 
auto_round_extension.vllm_ext.envs_ext import extra_environment_variables + from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables \ No newline at end of file diff --git a/auto_round_extension/vllm_ext/envs_ext.py b/auto_round_extension/vllm_ext/envs_ext.py index 325b9de7c..845854cbc 100644 --- a/auto_round_extension/vllm_ext/envs_ext.py +++ b/auto_round_extension/vllm_ext/envs_ext.py @@ -24,6 +24,7 @@ "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"), "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"), "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"), + "VLLM_AR_POST_PROCESS_GPTOSS": lambda: os.getenv("VLLM_AR_POST_PROCESS_GPTOSS", "0") in ("1", "true", "True"), } # Add the extra environment variables to vllm.envs import vllm.envs as envs diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index 77ec6570d..754840768 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -250,8 +250,10 @@ def revert_interleaved_bias(bias): return revert_bias # breakpoint() - w13_bias_swapped = revert_interleaved_bias(layer.w13_bias) - layer.w13_bias.data.copy_(w13_bias_swapped) + if self.has_bias: + if envs.VLLM_AR_POST_PROCESS_GPTOSS: + w13_bias_swapped = revert_interleaved_bias(layer.w13_bias) + layer.w13_bias.data.copy_(w13_bias_swapped) if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: @@ -271,8 +273,8 @@ def revert_interleaved_w1(w1): new_w1[:, ::2, :] = w1[:, : N // 2, :] new_w1[:, 1::2, :] = w1[:, N // 2 :, :] return new_w1 - - w1 = revert_interleaved_w1(w1) + if envs.VLLM_AR_POST_PROCESS_GPTOSS: + w1 = revert_interleaved_w1(w1) w1_scale = None w2 = layer.w2_weight_packed diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 15cea5ebc..5bd8014df 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -5,14 +5,18 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" -model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" +# model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" +# /home/yiliu7/workspace/torchutils/examples + # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ +VLLM_AR_MXFP4_MODULAR_MOE=1 \ +VLLM_ENABLE_STATIC_MOE=0 \ VLLM_USE_DEEP_GEMM=0 \ VLLM_LOGGING_LEVEL=DEBUG \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ python generate.py \ --model ${model_path} \ - --max-tokens 64 \ + --max-tokens 16 \ --enforce-eager \ No newline at end of file diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index e4824b7d6..3409ad368 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -1,6 +1,7 @@ export AR_LOG_LEVEL=TRACE model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" +model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" # model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" base_name=$(basename ${model}) scheme="MXFP4" From f0f0e1d4740994ac8be8741fbbf85f865a7378ee Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 23:27:36 -0400 Subject: [PATCH 03/52] add mxfp8 Signed-off-by: yiliu30 --- .../vllm_ext/auto_round_ext.py | 3 +- .../vllm_ext/linear_impl_mxfp8.py | 129 +++++++++ .../vllm_ext/mxfp8_qdq_utils.py | 65 +++++ auto_round_extension/vllm_ext/quant_impl.py | 45 +++ .../vllm_ext/quant_method_linear.py | 109 ++++++++ 
.../vllm_ext/quant_method_moe.py | 6 +- .../vllm_ext/tests/test_mxfp4_moe.py | 4 +- .../vllm_ext/torchao_patch.py | 256 ++++++++++++++++++ auto_round_extension/vllm_ext/utils.py | 12 + 9 files changed, 622 insertions(+), 7 deletions(-) create mode 100644 auto_round_extension/vllm_ext/linear_impl_mxfp8.py create mode 100644 auto_round_extension/vllm_ext/mxfp8_qdq_utils.py create mode 100644 auto_round_extension/vllm_ext/quant_impl.py create mode 100644 auto_round_extension/vllm_ext/quant_method_linear.py create mode 100644 auto_round_extension/vllm_ext/torchao_patch.py diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index cce2c912d..23d441daa 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme +from auto_round_extension.vllm_ext.quant_method_linear import AutoRoundQuantLinearMethod from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod logger = init_logger(__name__) @@ -36,7 +37,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix) return quant_method elif isinstance(layer, LinearBase): - return UnquantizedLinearMethod() + return AutoRoundQuantLinearMethod.get_method(self, layer, prefix) else: return None diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp8.py b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py new file mode 100644 index 000000000..0c6e1998a --- /dev/null +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
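+
+# NOTE (added for clarity; a descriptive sketch of what this module implements):
+# MXFP8 linear layers are emulated here. Weights are stored as float8_e4m3fn
+# with one uint8 E8M0 scale per 32-element group (group_size = 32). An E8M0
+# byte encodes a power-of-two scale, scale = 2.0 ** (int(byte) - 127), e.g.
+# byte 127 -> 1.0 and byte 130 -> 8.0. apply_weights below dequantizes the
+# weight, quantize-dequantizes the activation with the same scheme, and then
+# runs a plain matmul in the activation dtype.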
+ +from typing import Callable, Optional + +import torch +import vllm.envs as envs +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) + +from auto_round_extension.vllm_ext.mxfp8_qdq_utils import dequant_mx_fp8, quant_mx_fp8 +from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl + + +class AutoRoundMXFP8LinearImpl(AutoRoundQuantImpl): + def __init__(self, quant_scheme): + self.quant_scheme = quant_scheme + self.strategy = "TENSOR_GROUP" + self.out_dtype = torch.get_default_dtype() + self.is_static_input_scheme = False + self.group_size = 32 + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + def process_weights_after_loading(self, layer) -> None: + return + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + # maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.strategy == "TENSOR_GROUP": + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.uint8, # E8M0 for MXFP8 scale + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + else: + raise NotImplementedError(f"Strategy {self.strategy} is not supported for W8A8-MXFp8") + + # min requirement for fp8 kernels + # weight_scale[:] = torch.finfo(torch.float32).min + # weight_scale.fill_(torch.finfo(torch.float32).min) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # dequant weight + weight = layer.weight + weight_scale = layer.weight_scale + dequnat_weight = dequant_mx_fp8( + weight_fp8=weight.data, + scale_e8m0=weight_scale.data, + block_size=self.group_size, + ) + dequnat_weight = dequnat_weight.to(x.dtype) + # if not envs.VLLM_AR_MXFP8_DISABLE_INPUT_QDQ: + # q-dq input + x_scale, x_quant = quant_mx_fp8(x) + dequant_x = dequant_mx_fp8( + weight_fp8=x_quant, + scale_e8m0=x_scale, + block_size=self.group_size, + ) + x = dequant_x.to(x.dtype) + + out = x @ dequnat_weight.t() + return out.to(x.dtype) + (bias if bias is not None else 0) diff --git a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py new file mode 100644 index 000000000..aa23ff9fd --- /dev/null +++ b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +__all__ = ["get_fp_scale", 
"dequant_mx_fp8", "quant_mx_fp8"] + + +# def get_fp_scale(scale_e8m0): +# # https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337 +# assert scale_e8m0.dtype == torch.uint8, f"Expected uint8, got {scale_e8m0.dtype}" +# E8M0_EXPONENT_BIAS = 127 +# scale_e8m0 = scale_e8m0.view(torch.uint8) +# s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS +# # TODO(later): it would be nice if there was a way to do the 2^x operation +# # in PyTorch without creating a tensor of twos +# two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) +# # pow(two, s_offset) can be out of range of floating point formats. +# # TODO(later): handle this for float16 if we decide to support float16 +# # scales. +# s_fp = torch.pow(two, s_offset) + +# return s_fp + + +def get_fp_scale(scale_e8m0): + # https://github.com/pytorch/ao/blob/994a4ba6c869854fcaa6ca7e118fcbd75e6c28cc/torchao/prototype/mx_formats/mx_tensor.py#L337 + E8M0_EXPONENT_BIAS = 127 + + scale_e8m0 = scale_e8m0.view(torch.uint8) + s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS + # TODO(later): it would be nice if there was a way to do the 2^x operation + # in PyTorch without creating a tensor of twos + # two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) + # pow(two, s_offset) can be out of range of floating point formats. + # TODO(later): handle this for float16 if we decide to support float16 + # scales. + # s_fp = torch.pow(two, s_offset) + # !!!!NOTE Critical: fixed the OoM issue when using HPU graph + s_fp = torch.pow(2.0, s_offset.to(torch.float)) + + return s_fp + + +def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size): + scale_float = get_fp_scale(scale_e8m0) + weight_bf16 = weight_fp8.to(torch.bfloat16) + weight_original_shape = weight_bf16.shape + weight_bf16 = weight_bf16.reshape(-1, block_size) + scale_float = scale_float.reshape(-1, 1) + dequant_weight = weight_bf16 * scale_float + dequant_weight = dequant_weight.reshape(weight_original_shape) + return dequant_weight + + +def quant_mx_fp8(tensor): + from .torchao_patch import ScaleCalculationMode, to_mx + + scale_e8m0_biased, data_lp = to_mx( + data_hp=tensor, + elem_dtype=torch.float8_e4m3fn, + block_size=32, + scaling_mode=ScaleCalculationMode.RCEIL, + pack_fp6=False, + ) + return scale_e8m0_biased, data_lp diff --git a/auto_round_extension/vllm_ext/quant_impl.py b/auto_round_extension/vllm_ext/quant_impl.py new file mode 100644 index 000000000..b88843462 --- /dev/null +++ b/auto_round_extension/vllm_ext/quant_impl.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class AutoRoundQuantImpl(ABC): + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. 
+ """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ): + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + raise NotImplementedError diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py new file mode 100644 index 000000000..73f79cb65 --- /dev/null +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig + +from auto_round.schemes import QuantizationScheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8 + +logger = init_logger(__name__) + + +QLINEAR_METHODS_DISPATCH_TABLE = {} + + +class AutoRoundQuantLinearMethod(LinearMethodBase): + + def __init__(self, impl, config=None, scheme=None): + self.config = config + self.impl = impl + self.scheme = scheme + + @staticmethod + def get_method( + quant_config: AutoRoundConfig, + layer: torch.nn.Module, + prefix: str, + ) -> "AutoRoundQuantLinearMethod": + + def get_scheme(quant_config: AutoRoundConfig, prefix: str): + # Check extra_config first + layer_schemes = quant_config.layer_schemes + # FIXME: make more robust + for name, scheme in layer_schemes.items(): + if prefix.startswith(name): + return scheme + # If not found, use default + return quant_config.quant_scheme + + def check_quantized(weight_bits: int) -> bool: + return weight_bits < 16 + + def get_impl(scheme: QuantizationScheme): + if not check_quantized(scheme.bits): + + return UnquantizedLinearMethod() + + elif _is_mxfp8_w8a8(scheme): + from auto_round_extension.vllm_ext.linear_impl_mxfp8 import AutoRoundMXFP8LinearImpl + + return AutoRoundMXFP8LinearImpl(quant_config) + + raise ValueError(f"Unsupported Linear scheme: {scheme}") + + layer_scheme = get_scheme(quant_config, prefix) + impl = get_impl(layer_scheme) + logger.debug("Apply %s to %s", impl.__class__.__name__, prefix) + return AutoRoundQuantLinearMethod(impl=impl) + + @classmethod + def get_min_capability(cls) -> int: + return cls.impl.get_min_capability() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + return self.impl.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + 
output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + return self.impl.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + return self.impl.apply_weights(layer, x, bias=bias) diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index ae7b81775..dfb20116f 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4 logger = init_logger(__name__) @@ -31,11 +32,6 @@ QMOE_METHODS_DISPATCH_TABLE = {} -def _is_mxfp4_w4a4(scheme: QuantizationScheme): - # FIXME: below impl is incomplete - return scheme.bits == 4 and scheme.group_size == 32 - - class AutoRoundMoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) diff --git a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py index a2de32832..df748eb10 100644 --- a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py +++ b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py @@ -18,7 +18,9 @@ MODELS = [ # "/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/" # "/data5/yliu7/HF_HOME/Qwen2.5-0.5B-Instruct-test-FP8_STATIC-fp8kv/" - "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4" + # "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4", + # "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP4-fp8attention", + "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP8" ] diff --git a/auto_round_extension/vllm_ext/torchao_patch.py b/auto_round_extension/vllm_ext/torchao_patch.py new file mode 100644 index 000000000..d5b70c6cf --- /dev/null +++ b/auto_round_extension/vllm_ext/torchao_patch.py @@ -0,0 +1,256 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum, auto +from typing import Union + +import torch + +from .utils import _to_mx_rceil, get_fp_scale + + +class ScaleCalculationMode(Enum): + """ + Enum representing the different methods for calculating MX block scaling. + There are three methods available: + FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp). + It result in overflow issues for large values and bad for gradient quantization. + CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor. + It uses X = 2^ceil(log2(max_abs(v))-max_exp). + EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)). + It provides better accuracy for MX4 training compared to FLOOR and CEIL. 
+ RCEIL: The method is to apply ceil to the ratio of max_abs(v) and max_pos. + This method's detail is described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization + Section "Computing scaling and conversion factors for FP8 with UE8M0 scales" + + By default, we use the EVEN method for better accuracy. + """ + + FLOOR = auto() + CEIL = auto() + EVEN = auto() + RCEIL = auto() + + +# This is conceptually an enum of non-core dtypes +# TODO(future PR): change to a cleaner way to represent this without +# regressing torch.compile and while keeping things readable. +DTYPE_FP6_E3M2 = "fp6_e3m2" +DTYPE_FP6_E2M3 = "fp6_e2m3" + +# Supported element dtypes +# TODO(future PR): add support for MX int8 +SUPPORTED_ELEM_DTYPES = [ + torch.float8_e4m3fn, + torch.float8_e5m2, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, +] + + +F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 +F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0 + +F8E4M3_MAX_POW2 = 8 # 256 +F8E5M2_MAX_POW2 = 15 # 32768 +F6_E2M3_MAX_POW2 = 2 # 4 +F6_E3M2_MAX_POW2 = 4 # 16 +F4_E2M1_MAX_POW2 = 2 # 4 + +E8M0_EXPONENT_BIAS = 127 +E8M0_EXPONENT_NAN_VAL = 255 + +F32_EXP_BIAS = 127 +BF16_EXP_BIAS = 127 +F6_E2M3_EXP_BIAS = 1 +F6_E3M2_EXP_BIAS = 3 +F4_E2M1_EXP_BIAS = 1 + +F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) + +F6_E2M3_MAX = 7.5 +F6_E2M3_MIN_NORMAL = 1.0 +F6_E2M3_MAX_INT = 31 # integer corresponding to 0b00011111 + +F6_E3M2_MAX = 28.0 +F6_E3M2_MIN_NORMAL = 0.25 +F6_E3M2_MAX_INT = 31 # integer corresponding to 0b00011111 + +F4_E2M1_MAX = 6.0 +F4_E2M1_MIN_NORMAL = 1.0 +F4_E2M1_MAX_INT = 7 + +BLOCK_SIZE_DEFAULT = 32 + + +# TODO(later): read from somewhere else? +SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 +EBITS_BF16, MBITS_BF16 = 8, 7 +EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 +EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 +EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 +EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3 +EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2 + + +def to_mx( + data_hp: torch.Tensor, + elem_dtype: Union[torch.dtype, str], + block_size: int, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, + pack_fp6: bool = False, +): + """ + Takes a high precision tensor and converts to MX scale and raw data, in + naive layout (scale and raw data are separate tensors). + """ + + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + # TODO(future PR): consider supporting padding + assert data_hp.numel() % block_size == 0, "unsupported" + assert data_hp.is_contiguous(), "unsupported" + assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" + + # calculate the scale in e8m0 format + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(-1, block_size) + + # find max value of the data + # Note: this only implements the `minimally supported` version of + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + # section 6.3. + max_abs = torch.amax(torch.abs(data_hp), 1) + + # Add an epsilon to prevent the log2 function call for returning -inf + # where the values are zero. 
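+    # (F32_MIN_NORMAL is 2 ** -126, the smallest normal float32; the mask term
+    # equals 1.0 only for all-zero blocks, so eps is zero everywhere else.)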
+ eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) + + # Set X to be the largest power-of-two less than or equal to + # max_abs(v), divided by the largest power of two representable + # in the element data type, and get the mbits at the same time + if elem_dtype == torch.float8_e4m3fn: + target_max_pow2 = F8E4M3_MAX_POW2 + mbits = MBITS_F8_E4M3 + max_pos = F8E4M3_MAX + else: + raise AssertionError("unsupported element dtype") + + if scaling_mode == ScaleCalculationMode.RCEIL: + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + else: + if data_hp.dtype is torch.float32: + hp_int_dtype = torch.int32 + hp_mbits = MBITS_F32 + hp_ebits = EBITS_F32 + hp_exp_bias = F32_EXP_BIAS + else: + assert data_hp.dtype is torch.bfloat16 + hp_int_dtype = torch.int16 + hp_mbits = MBITS_BF16 + hp_ebits = EBITS_BF16 + hp_exp_bias = BF16_EXP_BIAS + + # rounding before calculating the largest power of 2 + # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) + if scaling_mode == ScaleCalculationMode.EVEN: + nan_mask = torch.isnan(max_abs) + max_abs = max_abs.view(hp_int_dtype) + val_to_add = 1 << (hp_mbits - mbits - 1) + mask = ((1 << (hp_ebits + SBITS)) - 1) << hp_mbits + max_abs = (max_abs + val_to_add) & mask + max_abs = max_abs.view(data_hp.dtype) + max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device, dtype=max_abs.dtype) + + # Calculate the scale for different modes + max_abs_int32 = (max_abs + eps).view(hp_int_dtype) + extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias + + if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): + scale_e8m0_unbiased = extracted_pow2 - target_max_pow2 + elif scaling_mode == ScaleCalculationMode.CEIL: + # round up: add one to scale if the mantissa is larger than 0 + # 0x7FFFFF is equal to 23 ones + mantissa_gt_one = (max_abs_int32 & 0x7FFFFF) > 0 + extracted_pow2 += mantissa_gt_one + scale_e8m0_unbiased = extracted_pow2 - target_max_pow2 + else: + raise AssertionError("unsupported scaling calculation mode") + + # Clamp to exponents that can be represented in e8m0 + # add one to positive range to capture NaNs + scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1) + + # Create the biased e8m0 representation and cast it to 8 bits + scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS + scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8) + + # Conversion to torch.uint8 sets NaN values to 0, fix this by + # explicitly setting known NaN values to 255 + scale_e8m0_biased = torch.where( + torch.isnan(max_abs), + E8M0_EXPONENT_NAN_VAL, + scale_e8m0_biased, + ) + + # For now, calculate the scale in floating point. + scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32) + + # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the + # float32 denormal range. For now, manually adjust the fp scale. This is + # relevant if all of the incoming block values are zeroes. + # See https://github.com/pytorch/pytorch/issues/125557 for details. + # Note: it would be more correct to set the minimum to 2**-127, but this + # does not work in triton either as it looks like subnormal value handling + # has some gaps. So, for now just set to the minimum normal value. 
+ scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) + + # scale and saturated cast the data elements to max of target dtype + data_lp = data_hp / scale_fp32.unsqueeze(1) + + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and not torch._dynamo.is_compiling(): + # As of 20250317, the Pytorch eager mode cast to `torch.float8_e4m3fn` + # is unsaturated. This cast is saturated in triton. If we are compute bound, + # we see a speedup if we remove this redundant clamp if we are compiling + # to triton. + # TODO(#1912): make the saturated cast work in eager mode and remove this + # workaround. + data_lp = torch.clamp(data_lp, min=-1 * max_pos, max=max_pos) + + # cast to target dtype + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + data_lp = data_lp.to(elem_dtype) + # need to reshape at the end to help inductor fuse things + data_lp = data_lp.reshape(orig_shape) + else: + raise AssertionError("unsupported") + + # scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + return scale_e8m0_biased, data_lp + + +def down_size(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + + +def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) diff --git a/auto_round_extension/vllm_ext/utils.py b/auto_round_extension/vllm_ext/utils.py index f98eb2325..77a5bd3f1 100644 --- a/auto_round_extension/vllm_ext/utils.py +++ b/auto_round_extension/vllm_ext/utils.py @@ -14,10 +14,22 @@ import torch +from auto_round.schemes import QuantizationScheme + E8M0_EXPONENT_BIAS = 127 E8M0_EXPONENT_NAN_VAL = 255 +def _is_mxfp4_w4a4(scheme: QuantizationScheme): + # FIXME: below impl is incomplete + return scheme.bits == 4 and scheme.group_size == 32 + + +def _is_mxfp8_w8a8(scheme: QuantizationScheme): + # FIXME: below impl is incomplete + return scheme.bits == 8 and scheme.group_size == 32 + + def get_fp_scale(scale_e8m0): scale_e8m0 = scale_e8m0.view(torch.uint8) s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS From ed23ef7c17b96b709aa55b5768fe555e1ff4ca82 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 23:28:38 -0400 Subject: [PATCH 04/52] rename test Signed-off-by: yiliu30 --- .../vllm_ext/tests/{test_mxfp4_moe.py => test_models.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename auto_round_extension/vllm_ext/tests/{test_mxfp4_moe.py => test_models.py} (100%) diff --git a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py b/auto_round_extension/vllm_ext/tests/test_models.py similarity index 100% rename from auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py rename to auto_round_extension/vllm_ext/tests/test_models.py From edba1ee7a2fcceb055c6b6547623a175c6ad2c72 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 23:33:23 -0400 Subject: [PATCH 05/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/mxfp8_qdq_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py index aa23ff9fd..1250180b6 100644 --- a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py @@ -53,7 +53,7 @@ def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size): def quant_mx_fp8(tensor): - from .torchao_patch import 
ScaleCalculationMode, to_mx + from auto_round_extension.vllm_ext.torchao_patch import ScaleCalculationMode, to_mx scale_e8m0_biased, data_lp = to_mx( data_hp=tensor, From d3d13b8dbaa36b8a083f401daf74109d5d93a26b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 23:35:20 -0400 Subject: [PATCH 06/52] clean Signed-off-by: yiliu30 --- .../vllm_ext/quant_method_linear.py | 15 +-------------- auto_round_extension/vllm_ext/quant_method_moe.py | 15 +-------------- auto_round_extension/vllm_ext/utils.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 28 deletions(-) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index 73f79cb65..5cc67db83 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8 +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, check_quantized, get_scheme logger = init_logger(__name__) @@ -42,19 +42,6 @@ def get_method( prefix: str, ) -> "AutoRoundQuantLinearMethod": - def get_scheme(quant_config: AutoRoundConfig, prefix: str): - # Check extra_config first - layer_schemes = quant_config.layer_schemes - # FIXME: make more robust - for name, scheme in layer_schemes.items(): - if prefix.startswith(name): - return scheme - # If not found, use default - return quant_config.quant_scheme - - def check_quantized(weight_bits: int) -> bool: - return weight_bits < 16 - def get_impl(scheme: QuantizationScheme): if not check_quantized(scheme.bits): diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index dfb20116f..36bd57aa4 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4 +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, check_quantized, get_scheme logger = init_logger(__name__) @@ -43,19 +43,6 @@ def get_moe_method( prefix: str, ) -> "AutoRoundMoEMethod": - def get_scheme(quant_config: AutoRoundConfig, prefix: str): - # Check extra_config first - layer_schemes = quant_config.layer_schemes - # FIXME: make more robust - for name, scheme in layer_schemes.items(): - if prefix.startswith(name): - return scheme - # If not found, use default - return quant_config.quant_scheme - - def check_quantized(weight_bits: int) -> bool: - return weight_bits < 16 - def get_impl(scheme: QuantizationScheme): if not check_quantized(scheme.bits): from vllm.model_executor.layers.fused_moe.layer import ( diff --git a/auto_round_extension/vllm_ext/utils.py b/auto_round_extension/vllm_ext/utils.py index 77a5bd3f1..80a564853 100644 --- a/auto_round_extension/vllm_ext/utils.py +++ b/auto_round_extension/vllm_ext/utils.py @@ -20,6 +20,21 @@ E8M0_EXPONENT_NAN_VAL = 255 +def get_scheme(quant_config, prefix: str): + # Check extra_config first + layer_schemes = quant_config.layer_schemes + # FIXME: make more robust + for name, scheme in layer_schemes.items(): + if prefix.startswith(name): + return scheme + # If not found, use default + return 
quant_config.quant_scheme + + +def check_quantized(weight_bits: int) -> bool: + return weight_bits < 16 + + def _is_mxfp4_w4a4(scheme: QuantizationScheme): # FIXME: below impl is incomplete return scheme.bits == 4 and scheme.group_size == 32 From dd909e453b524ab003250426938c00780fbd0bad Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 21:05:13 -0800 Subject: [PATCH 07/52] fix linear Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/quant_method_linear.py | 4 +++- inc_examples/quantize.py | 1 + inc_examples/run_gen.sh | 4 +++- inc_examples/run_quant.sh | 3 ++- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index 5cc67db83..c9d0ff6f4 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -41,6 +41,8 @@ def get_method( layer: torch.nn.Module, prefix: str, ) -> "AutoRoundQuantLinearMethod": + # FIXME: revert this WA after fixing scheme matching issue + return UnquantizedLinearMethod() def get_impl(scheme: QuantizationScheme): if not check_quantized(scheme.bits): @@ -52,7 +54,7 @@ def get_impl(scheme: QuantizationScheme): return AutoRoundMXFP8LinearImpl(quant_config) - raise ValueError(f"Unsupported Linear scheme: {scheme}") + raise ValueError(f"Unsupported Linear scheme: {scheme}, layer: {prefix}") layer_scheme = get_scheme(quant_config, prefix) impl = get_impl(layer_scheme) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index f8c15ccf8..2a711368d 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -10,6 +10,7 @@ def quant_model(args): fp_layers = "shared_experts,lm_head,mlp.gate" if args.skip_attn: fp_layers += ",self_attn" + # fp_layers += ",layers.0" logger.info(f"Using fp_layers: {fp_layers}") autoround = AutoRound( model=args.model, diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 5bd8014df..4dc5f5712 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -5,6 +5,7 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" +model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" @@ -19,4 +20,5 @@ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ python generate.py \ --model ${model_path} \ --max-tokens 16 \ - --enforce-eager \ No newline at end of file + --enforce-eager \ + --tensor_parallel_size 4 \ No newline at end of file diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index 3409ad368..b707515fb 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -2,7 +2,8 @@ export AR_LOG_LEVEL=TRACE model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -# model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" +model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" +model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" base_name=$(basename ${model}) scheme="MXFP4" # scheme="MXFP8" From d2ed6a79c3d79efed41b4cc412d5437a95b37cd6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 21:32:19 -0800 Subject: [PATCH 08/52] fix gate_up proj match Signed-off-by: yiliu30 --- .../vllm_ext/quant_method_linear.py | 23 +++++++++++-------- .../vllm_ext/quant_method_moe.py | 4 ++-- auto_round_extension/vllm_ext/utils.py | 2 +- inc_examples/run_gen.sh | 8 +++---- 4 files changed, 21 insertions(+), 16 
deletions(-) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index c9d0ff6f4..b067e4828 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, check_quantized, get_scheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, need_quantize, get_scheme logger = init_logger(__name__) @@ -41,24 +41,29 @@ def get_method( layer: torch.nn.Module, prefix: str, ) -> "AutoRoundQuantLinearMethod": - # FIXME: revert this WA after fixing scheme matching issue - return UnquantizedLinearMethod() def get_impl(scheme: QuantizationScheme): - if not check_quantized(scheme.bits): - - return UnquantizedLinearMethod() - - elif _is_mxfp8_w8a8(scheme): + if _is_mxfp8_w8a8(scheme): from auto_round_extension.vllm_ext.linear_impl_mxfp8 import AutoRoundMXFP8LinearImpl return AutoRoundMXFP8LinearImpl(quant_config) + elif _is_mxfp4_w4a4(scheme): + from auto_round_extension.vllm_ext.linear_impl_mxfp4 import AutoRoundMXFP4LinearImpl + + return AutoRoundMXFP4LinearImpl(quant_config) + raise ValueError(f"Unsupported Linear scheme: {scheme}, layer: {prefix}") + # TODO: use a more robust way to map layer names + if prefix.endswith("gate_up_proj"): + # update gate_up_proj to gate_proj + prefix = prefix.replace("gate_up_proj", "gate_proj") layer_scheme = get_scheme(quant_config, prefix) + if not need_quantize(layer_scheme.bits): + return UnquantizedLinearMethod() impl = get_impl(layer_scheme) - logger.debug("Apply %s to %s", impl.__class__.__name__, prefix) + logger.info("Apply %s to %s", impl.__class__.__name__, prefix) return AutoRoundQuantLinearMethod(impl=impl) @classmethod diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index 36bd57aa4..fdc695259 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, check_quantized, get_scheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, need_quantize, get_scheme logger = init_logger(__name__) @@ -44,7 +44,7 @@ def get_moe_method( ) -> "AutoRoundMoEMethod": def get_impl(scheme: QuantizationScheme): - if not check_quantized(scheme.bits): + if not need_quantize(scheme.bits): from vllm.model_executor.layers.fused_moe.layer import ( UnquantizedFusedMoEMethod, ) diff --git a/auto_round_extension/vllm_ext/utils.py b/auto_round_extension/vllm_ext/utils.py index 80a564853..f499a1c02 100644 --- a/auto_round_extension/vllm_ext/utils.py +++ b/auto_round_extension/vllm_ext/utils.py @@ -31,7 +31,7 @@ def get_scheme(quant_config, prefix: str): return quant_config.quant_scheme -def check_quantized(weight_bits: int) -> bool: +def need_quantize(weight_bits: int) -> bool: return weight_bits < 16 diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 4dc5f5712..9bb63ad03 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -5,7 +5,7 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 
model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" -model_path="quantized_models/Qwen3-235B-A22B-MXFP4" +# model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" @@ -16,9 +16,9 @@ VLLM_AR_MXFP4_MODULAR_MOE=1 \ VLLM_ENABLE_STATIC_MOE=0 \ VLLM_USE_DEEP_GEMM=0 \ VLLM_LOGGING_LEVEL=DEBUG \ -VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +VLLM_ENABLE_V1_MULTIPROCESSING=0 \ python generate.py \ --model ${model_path} \ --max-tokens 16 \ - --enforce-eager \ - --tensor_parallel_size 4 \ No newline at end of file + # --enforce-eager \ + # --tensor_parallel_size 4 \ No newline at end of file From 790a7203ebf62fe3063285f4fcddfc61132b4fc1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 21:35:23 -0800 Subject: [PATCH 09/52] add mxfp4 Signed-off-by: yiliu30 --- .../vllm_ext/linear_impl_mxfp4.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 auto_round_extension/vllm_ext/linear_impl_mxfp4.py diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py new file mode 100644 index 000000000..45ead99f5 --- /dev/null +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +import vllm.envs as envs +from vllm.logger import init_logger + + +from vllm.model_executor.parameter import GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter + +from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( + dequant_mxfp4_to_fp8, + mxfp4_gemm_with_unpacked_weight, + run_mxfp4_emulations, +) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +__all__ = ["AutoRoundMXFP4LinearImpl"] + +from auto_round_extension.vllm_ext.quant_impl import AutoRoundQuantImpl + + +class AutoRoundMXFP4LinearImpl(AutoRoundQuantImpl): + def __init__(self, quant_scheme): + self.quant_scheme = quant_scheme + self.group_size = 32 + + @classmethod + def get_min_capability(cls) -> int: + if envs.VLLM_USE_MXFP4_CT_EMULATIONS: + return 80 + return 100 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter( + data=torch.empty(sum(output_partition_sizes), input_size_per_partition // 2, dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_packed", weight) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + # dtype=torch.uint8, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer) -> None: + # FIXME: may dequant to bf16 + if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( + dequant_mxfp4_to_fp8, + mxfp4_gemm_with_unpacked_weight, + run_mxfp4_emulations, + ) + + 
weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( + data_lp=layer.weight_packed, + scale_e8m0=layer.weight_scale, + ) + del layer.weight_packed + del layer.weight_scale + layer.weight_packed = None + layer.weight_scale = None + layer.register_parameter( + "weight_unpacked_fp8", + torch.nn.Parameter( + weight_fp8, + requires_grad=False, + ), + ) + layer.register_parameter( + "weight_scale_bf16", + torch.nn.Parameter( + scale_bf16, + requires_grad=False, + ), + ) + + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale) + if bias is not None: + out = out + bias + return out + else: + out = mxfp4_gemm_with_unpacked_weight( + x=x, + weight_fp8=layer.weight_unpacked_fp8, + weight_scale_bf16=layer.weight_scale_bf16, + bias=bias, + ) + return out + From 2ad5558cb72ec851bfcd7db7d218723a27428a24 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 17:53:39 -0800 Subject: [PATCH 10/52] add recipes Signed-off-by: yiliu30 --- inc_examples/quantize.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 2a711368d..013cef33c 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -6,6 +6,29 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + + +recipes = { + "ds_mxfp8": { + "scheme": "MXFP8", + "fp_layers": "lm_head,mlp.gate", + }, + "ds_mxfp4": { + "scheme": "MXFP4", + "fp_layers": "lm_head,mlp.gate,self_attn", + }, + "qwen_mxfp8": { + "scheme": "MXFP8", + "fp_layers": "lm_head,mlp.gate", + }, + "qwen_mxfp4": { + "scheme": "MXFP4", + "fp_layers": "lm_head,mlp.gate,self_attn", + "iters": 200, + }, +} + + def quant_model(args): fp_layers = "shared_experts,lm_head,mlp.gate" if args.skip_attn: From 553529ae77f0d5e64dc9cc9853aa28eda38c8e5a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:41:47 -0800 Subject: [PATCH 11/52] fix qwen mxfp4 Signed-off-by: yiliu30 --- .../vllm_ext/quant_method_linear.py | 5 +++- inc_examples/quantize.py | 30 ++++++++++++------- inc_examples/run_gen.sh | 2 ++ inc_examples/run_quant.sh | 8 +++-- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index b067e4828..6721b1d40 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -57,8 +57,11 @@ def get_impl(scheme: QuantizationScheme): # TODO: use a more robust way to map layer names if prefix.endswith("gate_up_proj"): - # update gate_up_proj to gate_proj + # update gate_up_proj to gate_proj, assume both gate and up share the same quantization scheme prefix = prefix.replace("gate_up_proj", "gate_proj") + if prefix.endswith("qkv_proj"): + # update qkv_proj to q_proj, assume all qkv share the same quantization scheme + prefix = prefix.replace("qkv_proj", "q_proj") layer_scheme = get_scheme(quant_config, prefix) if not need_quantize(layer_scheme.bits): return UnquantizedLinearMethod() diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 013cef33c..1be7c2033 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -8,45 +8,46 @@ -recipes = { +topologies_config = { "ds_mxfp8": { "scheme": "MXFP8", "fp_layers": "lm_head,mlp.gate", + "iters": 0, }, "ds_mxfp4": { 
"scheme": "MXFP4", "fp_layers": "lm_head,mlp.gate,self_attn", + "iters": 0, }, "qwen_mxfp8": { "scheme": "MXFP8", "fp_layers": "lm_head,mlp.gate", + "iters": 0, }, "qwen_mxfp4": { "scheme": "MXFP4", "fp_layers": "lm_head,mlp.gate,self_attn", - "iters": 200, + "iters": 0, # TODO: set to 200 before merge }, } def quant_model(args): - fp_layers = "shared_experts,lm_head,mlp.gate" - if args.skip_attn: - fp_layers += ",self_attn" - # fp_layers += ",layers.0" - logger.info(f"Using fp_layers: {fp_layers}") + config = topologies_config[args.t] + + logger.info(f"Using fp_layers: {config['fp_layers']}") autoround = AutoRound( model=args.model, - scheme=args.scheme, + scheme=config["scheme"], enable_torch_compile=args.enable_torch_compile, - iters=args.iters, - fp_layers=fp_layers, + iters=config['iters'], + fp_layers=config["fp_layers"], ) logger.info(f"Save quantized model to {args.output_dir}") format_type = "auto_round" if args.use_autoround_format else "llm_compressor" autoround.quantize_and_save( format=format_type, - output_dir=args.output_dir, + output_dir=f"quantized_model_{args.t}", ) @@ -81,6 +82,13 @@ def quant_model(args): default="MXFP4", help="Quantization scheme to use. Available options: " + ", ".join(AVAILABLE_SCHEMES.keys()), ) + parser.add_argument( + "-t", + type=str, + choices=topologies_config.keys(), + default="qwen_mxfp4", + help="Quantization scheme to use. Available options: " + ", ".join(topologies_config.keys()), + ) parser.add_argument( "--enable_torch_compile", diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 9bb63ad03..06ee50ee2 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -5,6 +5,8 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" +model_path="quantized_model_qwen_mxfp8" +model_path="quantized_model_qwen_mxfp4" # model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index b707515fb..886b76939 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -2,15 +2,17 @@ export AR_LOG_LEVEL=TRACE model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" -model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" +# model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" +# model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" base_name=$(basename ${model}) scheme="MXFP4" # scheme="MXFP8" qmodel_dir="quantized_models/" mkdir -p ${qmodel_dir} output_dir="${qmodel_dir}/${base_name}-${scheme}" -python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format +# python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format +# python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format +python quantize.py --model $model -t qwen_mxfp4 --use_autoround_format # model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" # scheme="MXFP4" From 284c41e1ce78c3338dde01136a0a7cdbd7d31bec Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 23:09:51 -0800 Subject: [PATCH 12/52] add mxfp4 moe Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/linear_impl_mxfp8.py | 2 ++ auto_round_extension/vllm_ext/mxfp8_qdq_utils.py | 4 ++-- auto_round_extension/vllm_ext/quant_method_moe.py | 5 +++++ 3 files changed, 9 insertions(+), 2 
deletions(-) diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp8.py b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py index 0c6e1998a..c7cc4cd60 100644 --- a/auto_round_extension/vllm_ext/linear_impl_mxfp8.py +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py @@ -113,6 +113,7 @@ def apply_weights( weight_fp8=weight.data, scale_e8m0=weight_scale.data, block_size=self.group_size, + target_dtype=x.dtype, ) dequnat_weight = dequnat_weight.to(x.dtype) # if not envs.VLLM_AR_MXFP8_DISABLE_INPUT_QDQ: @@ -122,6 +123,7 @@ def apply_weights( weight_fp8=x_quant, scale_e8m0=x_scale, block_size=self.group_size, + target_dtype=x.dtype, ) x = dequant_x.to(x.dtype) diff --git a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py index 1250180b6..6b7ee60d9 100644 --- a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py @@ -41,7 +41,7 @@ def get_fp_scale(scale_e8m0): return s_fp -def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size): +def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size, target_dtype): scale_float = get_fp_scale(scale_e8m0) weight_bf16 = weight_fp8.to(torch.bfloat16) weight_original_shape = weight_bf16.shape @@ -49,7 +49,7 @@ def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size): scale_float = scale_float.reshape(-1, 1) dequant_weight = weight_bf16 * scale_float dequant_weight = dequant_weight.reshape(weight_original_shape) - return dequant_weight + return dequant_weight.to(target_dtype) def quant_mx_fp8(tensor): diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index fdc695259..49877e30b 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -56,6 +56,11 @@ def get_impl(scheme: QuantizationScheme): return AutoRoundMoEMethodMXFp4Impl(quant_config, layer.moe_config) + elif _is_mxfp8_w8a8(scheme): + from auto_round_extension.vllm_ext.moe_impl_mxfp8 import AutoRoundMoEMethodMXFp8Impl + + return AutoRoundMoEMethodMXFp8Impl(quant_config, layer.moe_config) + raise ValueError(f"Unsupported FusedMoe scheme: {scheme}") layer_scheme = get_scheme(quant_config, prefix) From 936ec4e656a7288ab75bb9d84d32888536759c28 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 23:42:15 -0800 Subject: [PATCH 13/52] fix skip layers Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/quant_method_linear.py | 2 +- inc_examples/quantize.py | 2 +- inc_examples/run_gen.sh | 3 ++- inc_examples/run_quant.sh | 6 ++++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index 6721b1d40..c13449594 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -58,7 +58,7 @@ def get_impl(scheme: QuantizationScheme): # TODO: use a more robust way to map layer names if prefix.endswith("gate_up_proj"): # update gate_up_proj to gate_proj, assume both gate and up share the same quantization scheme - prefix = prefix.replace("gate_up_proj", "gate_proj") + prefix = prefix.replace("gate_up_proj", "up_proj") if prefix.endswith("qkv_proj"): # update qkv_proj to q_proj, assume all qkv share the same quantization scheme prefix = prefix.replace("qkv_proj", "q_proj") diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 1be7c2033..b237826c6 100644 --- a/inc_examples/quantize.py +++ 
b/inc_examples/quantize.py @@ -11,7 +11,7 @@ topologies_config = { "ds_mxfp8": { "scheme": "MXFP8", - "fp_layers": "lm_head,mlp.gate", + "fp_layers": "lm_head", "iters": 0, }, "ds_mxfp4": { diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 06ee50ee2..6134a7035 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -6,7 +6,8 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" model_path="quantized_model_qwen_mxfp8" -model_path="quantized_model_qwen_mxfp4" +model_path="quantized_model_ds_mxfp8" +# model_path="quantized_model_qwen_mxfp4" # model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index 886b76939..059b8c813 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -6,13 +6,15 @@ model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" # model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" base_name=$(basename ${model}) scheme="MXFP4" -# scheme="MXFP8" +scheme="MXFP8" qmodel_dir="quantized_models/" mkdir -p ${qmodel_dir} output_dir="${qmodel_dir}/${base_name}-${scheme}" # python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format +python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format +# python quantize.py --model $model -t qwen_mxfp4 --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format -python quantize.py --model $model -t qwen_mxfp4 --use_autoround_format +# python quantize.py --model $model -t ds_mxfp8 --use_autoround_format # model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" # scheme="MXFP4" From 218f564565c255db5c16cee08339f4819962da5c Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 10 Nov 2025 00:19:52 -0800 Subject: [PATCH 14/52] update example Signed-off-by: yiliu30 --- inc_examples/quantize.py | 2 +- inc_examples/run_gen.sh | 6 ++++-- inc_examples/run_quant.sh | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index b237826c6..a6dad2962 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -16,7 +16,7 @@ }, "ds_mxfp4": { "scheme": "MXFP4", - "fp_layers": "lm_head,mlp.gate,self_attn", + "fp_layers": "lm_head,self_attn", "iters": 0, }, "qwen_mxfp8": { diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 6134a7035..74dd97459 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -6,7 +6,9 @@ export VLLM_ENABLE_V1_MULTIPROCESSING=0 model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" model_path="quantized_model_qwen_mxfp8" -model_path="quantized_model_ds_mxfp8" +# model_path="quantized_model_ds_mxfp8" +# model_path="quantized_model_ds_mxfp4" +# model_path="quantized_model_qwen_mxfp4" # model_path="quantized_model_qwen_mxfp4" # model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" @@ -15,10 +17,10 @@ model_path="quantized_model_ds_mxfp8" # /home/yiliu7/workspace/torchutils/examples # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ +# VLLM_LOGGING_LEVEL=DEBUG \ VLLM_AR_MXFP4_MODULAR_MOE=1 \ VLLM_ENABLE_STATIC_MOE=0 \ VLLM_USE_DEEP_GEMM=0 \ -VLLM_LOGGING_LEVEL=DEBUG \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ python generate.py \ --model ${model_path} \ diff --git a/inc_examples/run_quant.sh 
b/inc_examples/run_quant.sh index 059b8c813..a2f4d3488 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -11,8 +11,9 @@ qmodel_dir="quantized_models/" mkdir -p ${qmodel_dir} output_dir="${qmodel_dir}/${base_name}-${scheme}" # python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format -python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format -# python quantize.py --model $model -t qwen_mxfp4 --use_autoround_format +# python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format +python quantize.py --model /storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/ -t qwen_mxfp4 --use_autoround_format +python quantize.py --model /storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat -t ds_mxfp4 --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format # python quantize.py --model $model -t ds_mxfp8 --use_autoround_format # model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" From 347d6800afdaf07fd09d6f405ad2848acfdcf255 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 10 Nov 2025 00:41:35 -0800 Subject: [PATCH 15/52] clean code Signed-off-by: yiliu30 --- .../vllm_ext/linear_impl_mxfp8.py | 41 +++++-------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp8.py b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py index c7cc4cd60..cf1382054 100644 --- a/auto_round_extension/vllm_ext/linear_impl_mxfp8.py +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp8.py @@ -31,7 +31,6 @@ def __init__(self, quant_scheme): self.quant_scheme = quant_scheme self.strategy = "TENSOR_GROUP" self.out_dtype = torch.get_default_dtype() - self.is_static_input_scheme = False self.group_size = 32 @classmethod @@ -69,37 +68,19 @@ def create_weights( layer.register_parameter("weight", weight) # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == "TENSOR_GROUP": - # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.uint8, # E8M0 for MXFP8 scale - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - else: - raise NotImplementedError(f"Strategy {self.strategy} is not supported for W8A8-MXFp8") - - # min requirement for fp8 kernels - # weight_scale[:] = torch.finfo(torch.float32).min - # weight_scale.fill_(torch.finfo(torch.float32).min) + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.uint8, # E8M0 for MXFP8 scale + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) - # INPUT SCALE - if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - input_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("input_scale", input_scale) - def apply_weights( self, layer: torch.nn.Module, From 7eb99749bd8939c2225687d15403ae2fd2d0ee9d Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 01:02:51 -0800 Subject: [PATCH 16/52] add mxfp4-mxfp8-moe Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/envs_ext.py | 1 + auto_round_extension/vllm_ext/fp4_utils.py | 2 +- .../vllm_ext/moe_impl_mxfp4.py | 126 
++++++++++++------ .../vllm_ext/mxfp4_qdq_utils.py | 1 + .../vllm_ext/quant_method_linear.py | 6 + .../vllm_ext/quant_method_moe.py | 1 + .../vllm_ext/torchao_patch.py | 1 + inc_examples/quantize.py | 2 +- inc_examples/run_gen.sh | 61 ++++++++- inc_examples/run_quant.sh | 16 ++- 10 files changed, 158 insertions(+), 59 deletions(-) diff --git a/auto_round_extension/vllm_ext/envs_ext.py b/auto_round_extension/vllm_ext/envs_ext.py index 845854cbc..e35ed41bc 100644 --- a/auto_round_extension/vllm_ext/envs_ext.py +++ b/auto_round_extension/vllm_ext/envs_ext.py @@ -22,6 +22,7 @@ # Define extra environment variables extra_environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"), + "VLLM_MXFP4_PRE_UNPACK_TO_FP8": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "0") in ("1", "true", "True"), "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"), "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"), "VLLM_AR_POST_PROCESS_GPTOSS": lambda: os.getenv("VLLM_AR_POST_PROCESS_GPTOSS", "0") in ("1", "true", "True"), diff --git a/auto_round_extension/vllm_ext/fp4_utils.py b/auto_round_extension/vllm_ext/fp4_utils.py index adb67d190..fb7b6e935 100644 --- a/auto_round_extension/vllm_ext/fp4_utils.py +++ b/auto_round_extension/vllm_ext/fp4_utils.py @@ -51,7 +51,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: indices = indices.reshape(-1) # Handle odd length by padding if necessary - assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}" + # assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}" # Reshape to pair consecutive elements indices = indices.reshape(-1, 2) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index 754840768..c16844e6b 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -160,7 +160,21 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo from vllm.model_executor.layers.fused_moe.config import ( ocp_mx_moe_quant_config, ) - + if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: + self.input_dtype = "mxfp8_e4m3" + self.weight_dtype = "mxfp8_e4m3" + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + block_shape=None, + ) + self.input_dtype = "mxfp4" self.weight_dtype = "mxfp4" return ocp_mx_moe_quant_config( @@ -176,53 +190,60 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo ) return None + def _dequant_fp4_to_fp8(self, layer): + weight_name_lst = ["w13_weight", "w2_weight"] + + for weight_name_prefix in weight_name_lst: + weight_name = f"{weight_name_prefix}_packed" + weight = getattr(layer, weight_name) + weight_scale_name = f"{weight_name_prefix}_scale" + weight_scale = getattr(layer, weight_scale_name) + new_weight_name = f"{weight_name_prefix}_unpacked" + new_scale_name = weight_scale_name + num_experts, _, _ = weight.shape + unpacked_weight_lst = [] + scale_list = [] + for expert_index in range(num_experts): + weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( + 
data_lp=weight[expert_index], + scale_e8m0=weight_scale[expert_index], + ) + + unpacked_weight_lst.append(weight_fp8) + scale_list.append(scale_bf16) + unpacked_weight_fp8 = torch.stack(unpacked_weight_lst, dim=0) + scale_bf16 = torch.stack(scale_list, dim=0) + assert unpacked_weight_fp8.shape[0] == num_experts, ( + f"Expected {num_experts} unpacked weights, got " f"{unpacked_weight_fp8.shape[0]}" + ) + delattr(layer, weight_name) + delattr(layer, weight_scale_name) + layer.register_parameter( + new_weight_name, + torch.nn.Parameter( + unpacked_weight_fp8, + requires_grad=False, + ), + ) + layer.register_parameter( + new_scale_name, + torch.nn.Parameter( + scale_bf16, + requires_grad=False, + ), + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.info(f"Processing weights after loading for layer: {layer._prefix}") if envs.VLLM_ENABLE_STATIC_MOE: if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - weight_name_lst = ["w13_weight", "w2_weight"] - - for weight_name_prefix in weight_name_lst: - weight_name = f"{weight_name_prefix}_packed" - weight = getattr(layer, weight_name) - weight_scale_name = f"{weight_name_prefix}_scale" - weight_scale = getattr(layer, weight_scale_name) - new_weight_name = f"{weight_name_prefix}_unpacked" - new_scale_name = weight_scale_name - num_experts, _, _ = weight.shape - unpacked_weight_lst = [] - scale_list = [] - for expert_index in range(num_experts): - weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( - data_lp=weight[expert_index], - scale_e8m0=weight_scale[expert_index], - ) - - unpacked_weight_lst.append(weight_fp8) - scale_list.append(scale_bf16) - unpacked_weight_fp8 = torch.stack(unpacked_weight_lst, dim=0) - scale_bf16 = torch.stack(scale_list, dim=0) - assert unpacked_weight_fp8.shape[0] == num_experts, ( - f"Expected {num_experts} unpacked weights, got " f"{unpacked_weight_fp8.shape[0]}" - ) - delattr(layer, weight_name) - delattr(layer, weight_scale_name) - layer.register_parameter( - new_weight_name, - torch.nn.Parameter( - unpacked_weight_fp8, - requires_grad=False, - ), - ) - layer.register_parameter( - new_scale_name, - torch.nn.Parameter( - scale_bf16, - requires_grad=False, - ), - ) - + self._dequant_fp4_to_fp8(layer) + return elif envs.VLLM_AR_MXFP4_MODULAR_MOE: - + if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: + self._dequant_fp4_to_fp8(layer) + return + def revert_interleaved_bias(bias): """ Convert from blocked bias format to interleaved format. 
@@ -354,6 +375,23 @@ def apply( if envs.VLLM_AR_MXFP4_MODULAR_MOE: from vllm.model_executor.layers.fused_moe import fused_experts + if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: + w1 = layer.w13_weight_unpacked + w2 = layer.w2_weight_unpacked + out = fused_experts( + x, + w1, + w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) + return out if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: w1 = layer.w13_weight_unpacked w2 = layer.w2_weight_unpacked diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index 197d757fa..214c98c8d 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -161,6 +161,7 @@ def dequant_mxfp4_to_fp8(data_lp, scale_e8m0): def mxfp4_fp8_weight_to_bf16(weight_fp8, scale_bf16): origin_shape = weight_fp8.shape weight_fp8 = weight_fp8.reshape(-1, 32) + scale_bf16 = scale_bf16.reshape(-1, 1) assert weight_fp8.shape[0] == scale_bf16.shape[0], f"shape mismatch: {weight_fp8.shape} vs {scale_bf16.shape}" dequant_weight_bf16 = weight_fp8.to(torch.bfloat16) * scale_bf16 dequant_weight_bf16 = dequant_weight_bf16.reshape(origin_shape) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index c13449594..0bc19c4f4 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -55,6 +55,12 @@ def get_impl(scheme: QuantizationScheme): raise ValueError(f"Unsupported Linear scheme: {scheme}, layer: {prefix}") + packed_modules_mapping = quant_config.packed_modules_mapping + for packed_name, child_names in packed_modules_mapping.items(): + if prefix.endswith(packed_name): + prefix = prefix.replace(packed_name, child_names[0]) + break + # TODO: use a more robust way to map layer names if prefix.endswith("gate_up_proj"): # update gate_up_proj to gate_proj, assume both gate and up share the same quantization scheme diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index 49877e30b..d4da8f1d2 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -65,6 +65,7 @@ def get_impl(scheme: QuantizationScheme): layer_scheme = get_scheme(quant_config, prefix) impl = get_impl(layer_scheme) + layer._prefix = prefix logger.debug("Apply %s to %s", impl.__class__.__name__, prefix) return impl diff --git a/auto_round_extension/vllm_ext/torchao_patch.py b/auto_round_extension/vllm_ext/torchao_patch.py index d5b70c6cf..e8c891bb4 100644 --- a/auto_round_extension/vllm_ext/torchao_patch.py +++ b/auto_round_extension/vllm_ext/torchao_patch.py @@ -121,6 +121,7 @@ def to_mx( torch.float, ), f"{data_hp.dtype} is not supported yet" # TODO(future PR): consider supporting padding + data_hp = data_hp.contiguous() assert data_hp.numel() % block_size == 0, "unsupported" assert data_hp.is_contiguous(), "unsupported" assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index a6dad2962..7a9c1c542 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -47,7 +47,7 @@ def quant_model(args): format_type = "auto_round" if args.use_autoround_format else "llm_compressor" 
autoround.quantize_and_save( format=format_type, - output_dir=f"quantized_model_{args.t}", + output_dir=f"/storage/yiliu7/quantized_model_{args.t}", ) diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index 74dd97459..c2f8c1a06 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -1,5 +1,5 @@ -export VLLM_LOGGING_LEVEL=DEBUG -export VLLM_ENABLE_V1_MULTIPROCESSING=0 +# export VLLM_LOGGING_LEVEL=DEBUG +# export VLLM_ENABLE_V1_MULTIPROCESSING=0 @@ -12,18 +12,65 @@ model_path="quantized_model_qwen_mxfp8" # model_path="quantized_model_qwen_mxfp4" # model_path="quantized_models/Qwen3-235B-A22B-MXFP4" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" - - +model_path="/storage/yiliu7/quantized_model_ds_mxfp8" +model_path="/storage/yiliu7/quantized_model_ds_mxfp4" # /home/yiliu7/workspace/torchutils/examples # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ # VLLM_LOGGING_LEVEL=DEBUG \ +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_AR_MXFP4_MODULAR_MOE=1 \ +# VLLM_ENABLE_STATIC_MOE=0 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# python generate.py \ +# --model ${model_path} \ +# --tensor_parallel_size 8 \ +# --max-tokens 16 \ +# --max-num-seqs 32 \ +# --gpu_memory_utilization 0.9 \ +# --distributed_executor_backend mp +# # --tensor_parallel_size 4 + + +# VLLM_LOGGING_LEVEL=DEBUG \ +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_AR_MXFP4_MODULAR_MOE=0 \ +# VLLM_ENABLE_STATIC_MOE=1 \ +# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# python generate.py \ +# --model ${model_path} \ +# --tensor_parallel_size 4 \ +# --max-tokens 16 \ +# --max-num-seqs 32 \ +# --gpu_memory_utilization 0.9 \ +# --distributed_executor_backend mp \ +# --enforce-eager +# # --tensor_parallel_size 4 + +# # --enforce-eager \ +# # --max-model-len 1024 \ +# VLLM_LOGGING_LEVEL=DEBUG \ +# model_path="/home/yiliu7/workspace/auto-round/inc_examples/quantized_model_ds_mxfp4" + VLLM_AR_MXFP4_MODULAR_MOE=1 \ +VLLM_ENABLE_AR_EXT=1 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ VLLM_ENABLE_STATIC_MOE=0 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ VLLM_USE_DEEP_GEMM=0 \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ -python generate.py \ + python generate.py \ --model ${model_path} \ + --tensor_parallel_size 8 \ --max-tokens 16 \ - # --enforce-eager \ - # --tensor_parallel_size 4 \ No newline at end of file + --max-num-seqs 2 \ + --gpu_memory_utilization 0.75 + # \ + # --enforce-eager + # --tensor_parallel_size 4 + + # --enforce-eager \ + # --max-model-len 1024 \ \ No newline at end of file diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index a2f4d3488..22b9c3fa7 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -1,9 +1,10 @@ export AR_LOG_LEVEL=TRACE -model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -# model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" -# model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" +qwen_model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" +# ds_model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" +ds_model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" +# ds_model="/storage/yiliu7/unsloth/DeepSeek-R1-BF16" +qwen_model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" base_name=$(basename ${model}) scheme="MXFP4" scheme="MXFP8" @@ -12,8 +13,11 @@ mkdir -p ${qmodel_dir} output_dir="${qmodel_dir}/${base_name}-${scheme}" # python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 
--use_autoround_format -python quantize.py --model /storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/ -t qwen_mxfp4 --use_autoround_format -python quantize.py --model /storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat -t ds_mxfp4 --use_autoround_format +# python quantize.py --model $qwen_model -t qwen_mxfp4 --use_autoround_format +# python quantize.py --model $qwen_model -t qwen_mxfp8 --use_autoround_format +# python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format +# python quantize.py --model $ds_model -t ds_mxfp8 --use_autoround_format +python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format # python quantize.py --model $model -t ds_mxfp8 --use_autoround_format # model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" From 493f2df2464f0a4d188889791961fe32436166c9 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 19:22:47 -0800 Subject: [PATCH 17/52] fix moe mxfp8 Signed-off-by: yiliu30 --- inc_examples/run_gen.sh | 8 ++++++-- inc_examples/run_quant.sh | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index c2f8c1a06..c8323a78e 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -14,6 +14,8 @@ model_path="quantized_model_qwen_mxfp8" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" model_path="/storage/yiliu7/quantized_model_ds_mxfp8" model_path="/storage/yiliu7/quantized_model_ds_mxfp4" +model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" +tp_size=4 # /home/yiliu7/workspace/torchutils/examples # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ @@ -64,10 +66,12 @@ VLLM_USE_DEEP_GEMM=0 \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ python generate.py \ --model ${model_path} \ - --tensor_parallel_size 8 \ + --tensor_parallel_size $tp_size \ --max-tokens 16 \ --max-num-seqs 2 \ - --gpu_memory_utilization 0.75 + --gpu_memory_utilization 0.75 \ + --no-enable-prefix-caching \ + --enable_expert_parallel # \ # --enforce-eager # --tensor_parallel_size 4 diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh index 22b9c3fa7..83a7dae38 100644 --- a/inc_examples/run_quant.sh +++ b/inc_examples/run_quant.sh @@ -13,11 +13,11 @@ mkdir -p ${qmodel_dir} output_dir="${qmodel_dir}/${base_name}-${scheme}" # python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format -# python quantize.py --model $qwen_model -t qwen_mxfp4 --use_autoround_format +python quantize.py --model $qwen_model -t qwen_mxfp4 --use_autoround_format # python quantize.py --model $qwen_model -t qwen_mxfp8 --use_autoround_format # python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format # python quantize.py --model $ds_model -t ds_mxfp8 --use_autoround_format -python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format +# python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format # python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format # python quantize.py --model $model -t ds_mxfp8 --use_autoround_format # model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" From fbc04aebcfb0fd0b546f85fd5fbbfe5644d921ff Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 21:27:55 -0800 Subject: [PATCH 18/52] fix Signed-off-by: yiliu30 --- inc_examples/run_eval.sh | 102 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 inc_examples/run_eval.sh 
diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh new file mode 100644 index 000000000..e7208bbbc --- /dev/null +++ b/inc_examples/run_eval.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# Check if a model name is passed as an argument, otherwise use the default model path +if [ -z "$1" ]; then +# model_path="Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound" + # model_path="/storage/yiliu7/quantized_model_ds_mxfp8" + model_path="/storage/yiliu7/quantized_model_ds_mxfp4" + model_path="/storage/yiliu7/quantized_model_ds_mxfp4" + model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" + model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" +else + model_path="$1" +fi + +tp_size=4 +model_name=$(basename ${model_path}) +output_dir="${model_name}-tp${tp_size}-gsm8k-acc" +task_name="mmlu" + +echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${output_dir}" +# VLLM_ATTENTION_BACKEND=TRITON_ATTN \ +mkdir -p ${output_dir} +# VLLM_ATTENTION_BACKEND=FLASHINFER \ + +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_AR_MXFP4_MODULAR_MOE=1 \ +# VLLM_ENABLE_STATIC_MOE=0 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_AR_MXFP4_MODULAR_MOE=1 \ +# VLLM_ENABLE_STATIC_MOE=0 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_LOGGING_LEVEL=DEBUG \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_LOGGING_LEVEL=DEBUG \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# lm_eval --model vllm \ +# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ +# --tasks $task_name \ +# --batch_size 16 \ +# --limit 32 \ +# --log_samples \ +# --seed 42 \ +# --output_path ${output_dir} \ +# --show_config 2>&1 | tee ${output_dir}/log.txt + +# + + +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_AR_MXFP4_MODULAR_MOE=1 \ +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ +# VLLM_ENABLE_STATIC_MOE=0 \ +# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# lm_eval --model vllm \ +# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ +# --tasks $task_name \ +# --batch_size 16 \ +# --limit 256 \ +# --log_samples \ +# --seed 42 \ +# --output_path ${output_dir} \ +# --show_config 2>&1 | tee ${output_dir}/log.txt + + +# /storage/yiliu7/quantized_model_qwen_mxfp4 4x200 +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_AR_MXFP4_MODULAR_MOE=1 \ +# VLLM_ENABLE_AR_EXT=1 \ +# VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +# VLLM_ENABLE_STATIC_MOE=0 \ +# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=1 \ +# VLLM_USE_DEEP_GEMM=0 \ +# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +# lm_eval --model vllm \ +# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ +# --tasks $task_name \ +# --batch_size 16 \ +# --limit 256 \ +# --log_samples \ +# --seed 42 \ +# --output_path ${output_dir} \ +# --show_config 2>&1 | tee ${output_dir}/log.txt + +VLLM_ENABLE_AR_EXT=1 \ +VLLM_AR_MXFP4_MODULAR_MOE=1 \ +VLLM_ENABLE_AR_EXT=1 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +VLLM_ENABLE_STATIC_MOE=0 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +VLLM_USE_DEEP_GEMM=0 \ 
+VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +lm_eval --model vllm \ + --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ + --tasks $task_name \ + --log_samples \ + --seed 42 \ + --output_path ${output_dir} \ + --show_config 2>&1 | tee ${output_dir}/log.txt + From bb4d90c249eab42b60459726cb003864e71f1e93 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 21:28:34 -0800 Subject: [PATCH 19/52] fix Signed-off-by: yiliu30 --- inc_examples/run_eval.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh index e7208bbbc..bc9d9d79b 100644 --- a/inc_examples/run_eval.sh +++ b/inc_examples/run_eval.sh @@ -14,7 +14,7 @@ fi tp_size=4 model_name=$(basename ${model_path}) output_dir="${model_name}-tp${tp_size}-gsm8k-acc" -task_name="mmlu" +task_name="gsm8k" echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${output_dir}" # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ @@ -83,7 +83,8 @@ mkdir -p ${output_dir} # --seed 42 \ # --output_path ${output_dir} \ # --show_config 2>&1 | tee ${output_dir}/log.txt - + +# /storage/yiliu7/quantized_model_qwen_mxfp8 4x200 VLLM_ENABLE_AR_EXT=1 \ VLLM_AR_MXFP4_MODULAR_MOE=1 \ VLLM_ENABLE_AR_EXT=1 \ @@ -95,6 +96,8 @@ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ lm_eval --model vllm \ --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ --tasks $task_name \ + --batch_size 16 \ + --limit 256 \ --log_samples \ --seed 42 \ --output_path ${output_dir} \ From edd3e9e6b76f2066b5e6d8d7d170213df8810b4d Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 21:57:44 -0800 Subject: [PATCH 20/52] add readme Signed-off-by: yiliu30 --- inc_examples/README.md | 69 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 inc_examples/README.md diff --git a/inc_examples/README.md b/inc_examples/README.md new file mode 100644 index 000000000..ce8c93b91 --- /dev/null +++ b/inc_examples/README.md @@ -0,0 +1,69 @@ +### Quantize Model + +- MXFP8 +```bash +export QWEN_MODEL=Qwen/Qwen3-235B-A22B +export DS_MODEL=deepseek-ai/DeepSeek-R1 + +python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format +python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format + + +- MXFP4 +```bash +export QWEN_MODEL=Qwen/Qwen3-235B-A22B +export DS_MODEL=deepseek-ai/DeepSeek-R1 +python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format +python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format +``` + + +### Prompt Tests +- MXFP8 +```bash +export model_path=/path/to/quantized_model +tp_size=8 +VLLM_AR_MXFP4_MODULAR_MOE=0 \ +VLLM_ENABLE_AR_EXT=1 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +VLLM_ENABLE_STATIC_MOE=0 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +VLLM_USE_DEEP_GEMM=0 \ +VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + python generate.py \ + --model ${model_path} \ + --tensor_parallel_size $tp_size \ + --max-tokens 16 \ + --max-num-seqs 4 \ + --gpu_memory_utilization 0.75 \ + --no-enable-prefix-caching \ + --enable_expert_parallel +``` + +- MXFP4 +```bash +export model_path=/path/to/quantized_model +tp_size=8 
+VLLM_AR_MXFP4_MODULAR_MOE=1 \ +VLLM_ENABLE_AR_EXT=1 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ +VLLM_ENABLE_STATIC_MOE=0 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +VLLM_USE_DEEP_GEMM=0 \ +VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + python generate.py \ + --model ${model_path} \ + --tensor_parallel_size $tp_size \ + --max-tokens 16 \ + --max-num-seqs 4 \ + --gpu_memory_utilization 0.75 \ + --no-enable-prefix-caching \ + --enable_expert_parallel +``` + +### Evaluation Tests + +WIP + + + From 5f799b8d356fcc324b74d41412962b327cf547e5 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:06:25 -0800 Subject: [PATCH 21/52] fix Signed-off-by: yiliu30 --- inc_examples/README.md | 2 +- inc_examples/generate.py | 1 + inc_examples/quantize.py | 31 ++----------------------------- 3 files changed, 4 insertions(+), 30 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index ce8c93b91..80acc865b 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -7,7 +7,7 @@ export DS_MODEL=deepseek-ai/DeepSeek-R1 python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format - +``` - MXFP4 ```bash diff --git a/inc_examples/generate.py b/inc_examples/generate.py index de9781c60..6a255be69 100644 --- a/inc_examples/generate.py +++ b/inc_examples/generate.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copied from https://github.com/vllm-project/vllm/ try: from auto_round_extension.vllm_ext import apply as apply_auto_round_extension diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 7a9c1c542..bf0a4fcc3 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -47,27 +47,13 @@ def quant_model(args): format_type = "auto_round" if args.use_autoround_format else "llm_compressor" autoround.quantize_and_save( format=format_type, - output_dir=f"/storage/yiliu7/quantized_model_{args.t}", + output_dir=f"{args.output_dir}/quantized_model_{args.t}", ) if __name__ == "__main__": import argparse - # import ar_schemes # Assuming `ar_schemes` is a module in your project - import auto_round.schemes as ar_schemes - - # Define available schemes - AVAILABLE_SCHEMES = { - "MXFP8": "MXFP8", - "FP8_STATIC": ar_schemes.FP8_STATIC, - "MXFP8_AR": ar_schemes.MXFP8, - "MXFP4_AR": ar_schemes.MXFP4, - "MXFP4": "MXFP4", - "W4A16": "W4A16", - "NVFP4": ar_schemes.NVFP4, - } - # Parse command-line arguments parser = argparse.ArgumentParser(description="Select a quantization scheme.") parser.add_argument( @@ -75,13 +61,6 @@ def quant_model(args): type=str, help="Path to the pre-trained model or model identifier from Hugging Face Hub.", ) - parser.add_argument( - "--scheme", - type=str, - choices=AVAILABLE_SCHEMES.keys(), - default="MXFP4", - help="Quantization scheme to use. 
Available options: " + ", ".join(AVAILABLE_SCHEMES.keys()), - ) parser.add_argument( "-t", type=str, @@ -112,19 +91,13 @@ def quant_model(args): default=0, help="Number of iterations for quantization.", ) - # output_dir can also be added as an argument if needed parser.add_argument( "--output_dir", type=str, - default="quantized_model", + default="./", help="Directory to save the quantized model.", ) args = parser.parse_args() - # Set the scheme based on user input - scheme = AVAILABLE_SCHEMES[args.scheme] - - # Print the selected scheme for confirmation - logger.info(f"Selected quantization scheme: {args.scheme}") quant_model(args) From 84f3dbe8fc81ae5fea9e5a648d19100d5a87adac Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:07:48 -0800 Subject: [PATCH 22/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/inc_examples/README.md b/inc_examples/README.md index 80acc865b..0bbd58e4f 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -1,3 +1,11 @@ + +## Support Matrix + +| Model Family | MXFP4 | MXFP8 | +|-------------|-------|-------| +| Qwen/Qwen3-235B-A22B | ✅ | ✅ | +| deepseek-ai/DeepSeek-R1 | ✅ | ✅ | + ### Quantize Model - MXFP8 From c9dbac02f8aff34bd9ba2957d9a98a67d7c54432 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:20:55 -0800 Subject: [PATCH 23/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 54 +++++++++++------------------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index 0bbd58e4f..2a526c0b1 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -1,10 +1,10 @@ ## Support Matrix -| Model Family | MXFP4 | MXFP8 | -|-------------|-------|-------| -| Qwen/Qwen3-235B-A22B | ✅ | ✅ | -| deepseek-ai/DeepSeek-R1 | ✅ | ✅ | +| Model Family | MXFP4 | MXFP8 | +| ----------------------- | ----- | ----- | +| Qwen/Qwen3-235B-A22B | ✅ | ✅ | +| deepseek-ai/DeepSeek-R1 | ✅ | ✅ | ### Quantize Model @@ -27,48 +27,22 @@ python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format ### Prompt Tests -- MXFP8 + +Usage: ```bash -export model_path=/path/to/quantized_model -tp_size=8 -VLLM_AR_MXFP4_MODULAR_MOE=0 \ -VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ -VLLM_ENABLE_STATIC_MOE=0 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -VLLM_USE_DEEP_GEMM=0 \ -VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - python generate.py \ - --model ${model_path} \ - --tensor_parallel_size $tp_size \ - --max-tokens 16 \ - --max-num-seqs 4 \ - --gpu_memory_utilization 0.75 \ - --no-enable-prefix-caching \ - --enable_expert_parallel +./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] ``` +- MXFP8 +```bash +bash ./run_generate.sh -s mxfp8 -m /path/to/qwen_mxfp8 -tp 4 +bash ./run_generate.sh -s mxfp8 -m /path/to/ds_mxfp8 -tp 8 +``` - MXFP4 ```bash -export model_path=/path/to/quantized_model -tp_size=8 -VLLM_AR_MXFP4_MODULAR_MOE=1 \ -VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ -VLLM_ENABLE_STATIC_MOE=0 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -VLLM_USE_DEEP_GEMM=0 \ -VLLM_ENABLE_V1_MULTIPROCESSING=1 \ - python generate.py \ - --model ${model_path} \ - --tensor_parallel_size $tp_size \ - --max-tokens 16 \ - --max-num-seqs 4 \ - --gpu_memory_utilization 0.75 \ - --no-enable-prefix-caching \ - --enable_expert_parallel +bash ./run_generate.sh -s mxfp4 -m /path/to/qwen_mxfp4 -tp 4 +bash ./run_generate.sh -s mxfp4 -m /path/to/ds_mxfp4 -tp 8 ``` - ### Evaluation Tests WIP 
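The "### Evaluation Tests" section above is still marked WIP; the sketch below is one way such a run could look, loosely following the `inc_examples/run_eval.sh` script added earlier in this series and the MXFP8 environment settings shown in the earlier README Prompt Tests block. It is a sketch only, not the project's finalized evaluation flow: `MODEL_PATH`, `TP_SIZE`, the task list, and the output directory are placeholders rather than values taken from the patches, and for MXFP4 checkpoints the MXFP4 block above suggests flipping `VLLM_AR_MXFP4_MODULAR_MOE` and `VLLM_MXFP4_PRE_UNPACK_TO_FP8` to 1.

```bash
#!/bin/bash
# Minimal evaluation sketch (assumed values are marked as placeholders).
# Env knobs mirror the MXFP8 settings from the README Prompt Tests block;
# the lm_eval invocation follows inc_examples/run_eval.sh.
MODEL_PATH=/path/to/quantized_model   # placeholder
TP_SIZE=4                             # placeholder
TASKS=gsm8k                           # placeholder task list
OUTPUT_DIR="$(basename "${MODEL_PATH}")-tp${TP_SIZE}-acc"
mkdir -p "${OUTPUT_DIR}"

VLLM_ENABLE_AR_EXT=1 \
VLLM_AR_MXFP4_MODULAR_MOE=0 \
VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \
VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \
VLLM_ENABLE_STATIC_MOE=0 \
VLLM_USE_DEEP_GEMM=0 \
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
lm_eval --model vllm \
  --model_args "pretrained=${MODEL_PATH},tensor_parallel_size=${TP_SIZE},add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,enable_prefix_caching=False" \
  --tasks "${TASKS}" \
  --batch_size 16 \
  --seed 42 \
  --output_path "${OUTPUT_DIR}" \
  --show_config 2>&1 | tee "${OUTPUT_DIR}/log.txt"
```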
From 3d21a74f63ad176ed17525cf8badc3ab5de75b65 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:21:17 -0800 Subject: [PATCH 24/52] add gene Signed-off-by: yiliu30 --- inc_examples/run_generate.sh | 117 +++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 inc_examples/run_generate.sh diff --git a/inc_examples/run_generate.sh b/inc_examples/run_generate.sh new file mode 100644 index 000000000..2ca2dee30 --- /dev/null +++ b/inc_examples/run_generate.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Model Testing Script +# Usage: ./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] + +# Default values +QUANT_TYPE="mxfp8" +MODEL_PATH="/path/to/quantized_model" +TP_SIZE=8 + +# Function to display usage +usage() { + echo "Usage: $0 -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size]" + echo " -s: Quantization scheme (mxfp4 or mxfp8, default: mxfp8)" + echo " -m: Path to quantized model (required)" + echo " -tp: Tensor parallelism size (default: 8)" + echo "" + echo "Examples:" + echo " $0 -s mxfp4 -m /path/to/my/model -tp 4" + echo " $0 -m /path/to/my/model" + echo " $0 -s mxfp8 -m /path/to/my/model" +} + +# Parse command line arguments +while getopts "s:m:tp:h" opt; do + case $opt in + s) + QUANT_TYPE="$OPTARG" + ;; + m) + MODEL_PATH="$OPTARG" + ;; + tp) + TP_SIZE="$OPTARG" + ;; + h) + usage + exit 0 + ;; + \?) + echo "Invalid option: -$OPTARG" >&2 + usage + exit 1 + ;; + :) + echo "Option -$OPTARG requires an argument." >&2 + usage + exit 1 + ;; + esac +done + +# Validate quantization type +QUANT_TYPE_UPPER=$(echo "$QUANT_TYPE" | tr '[:lower:]' '[:upper:]') +if [[ "$QUANT_TYPE_UPPER" != "MXFP4" && "$QUANT_TYPE_UPPER" != "MXFP8" ]]; then + echo "Error: Quantization type must be mxfp4 or mxfp8" + usage + exit 1 +fi + +# Validate model path +if [[ "$MODEL_PATH" == "/path/to/quantized_model" ]]; then + echo "Error: Model path is required (-m)" + usage + exit 1 +fi + +if [[ ! -d "$MODEL_PATH" ]]; then + echo "Error: Model path '$MODEL_PATH' does not exist or is not a directory" + exit 1 +fi + +# Validate TP_SIZE is a number +if ! [[ "$TP_SIZE" =~ ^[0-9]+$ ]] || [ "$TP_SIZE" -lt 1 ]; then + echo "Error: Tensor parallelism size must be a positive integer" + exit 1 +fi + +echo "Running $QUANT_TYPE_UPPER test with:" +echo " Model: $MODEL_PATH" +echo " Tensor Parallelism: $TP_SIZE" +echo "" + +# Set environment variables based on quantization type +if [[ "$QUANT_TYPE_UPPER" == "MXFP4" ]]; then + export VLLM_AR_MXFP4_MODULAR_MOE=1 + export VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 + echo "Using MXFP4 configuration" +else + export VLLM_AR_MXFP4_MODULAR_MOE=0 + export VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 + echo "Using MXFP8 configuration" +fi + +# Common environment variables +export VLLM_ENABLE_AR_EXT=1 +export VLLM_ENABLE_STATIC_MOE=0 +export VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 +export VLLM_USE_DEEP_GEMM=0 +export VLLM_ENABLE_V1_MULTIPROCESSING=1 + +echo "Environment variables set:" +echo " VLLM_AR_MXFP4_MODULAR_MOE=$VLLM_AR_MXFP4_MODULAR_MOE" +echo " VLLM_MXFP4_PRE_UNPACK_TO_FP8=$VLLM_MXFP4_PRE_UNPACK_TO_FP8" +echo " VLLM_ENABLE_AR_EXT=$VLLM_ENABLE_AR_EXT" +echo "" + +# Run the model +echo "Starting model generation..." 
+python generate.py \ + --model "${MODEL_PATH}" \ + --tensor_parallel_size $TP_SIZE \ + --max-tokens 16 \ + --max-num-seqs 4 \ + --gpu_memory_utilization 0.75 \ + --no-enable-prefix-caching \ + --enable_expert_parallel \ No newline at end of file From 01665c90a21095e42eaec34e10d1816eb6f4a34a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:21:34 -0800 Subject: [PATCH 25/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index 2a526c0b1..44271ccc2 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -30,7 +30,7 @@ python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format Usage: ```bash -./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] +bash ./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] ``` - MXFP8 From ebe9d7905780261fbdf148cce0645416c5c6cb5f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:22:31 -0800 Subject: [PATCH 26/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index 44271ccc2..aa5effe87 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -7,20 +7,20 @@ | deepseek-ai/DeepSeek-R1 | ✅ | ✅ | ### Quantize Model - -- MXFP8 +- Export model path ```bash export QWEN_MODEL=Qwen/Qwen3-235B-A22B export DS_MODEL=deepseek-ai/DeepSeek-R1 +``` +- MXFP8 +```bash python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format ``` - MXFP4 ```bash -export QWEN_MODEL=Qwen/Qwen3-235B-A22B -export DS_MODEL=deepseek-ai/DeepSeek-R1 python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format ``` From 7b986e3ad5831df215e7f85a764d7e423f74ee95 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:24:56 -0800 Subject: [PATCH 27/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index aa5effe87..ed58d539e 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -30,18 +30,18 @@ python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format Usage: ```bash -bash ./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] +bash ./run_generate.sh -s [mxfp4|mxfp8] -tp [tensor_parallel_size] -m [model_path] ``` - MXFP8 ```bash -bash ./run_generate.sh -s mxfp8 -m /path/to/qwen_mxfp8 -tp 4 -bash ./run_generate.sh -s mxfp8 -m /path/to/ds_mxfp8 -tp 8 +bash ./run_generate.sh -s mxfp8 -tp 4 -m /path/to/qwen_mxfp8 +bash ./run_generate.sh -s mxfp8 -tp 8 -m /path/to/ds_mxfp8 ``` - MXFP4 ```bash -bash ./run_generate.sh -s mxfp4 -m /path/to/qwen_mxfp4 -tp 4 -bash ./run_generate.sh -s mxfp4 -m /path/to/ds_mxfp4 -tp 8 +bash ./run_generate.sh -s mxfp4 -tp 4 -m /path/to/qwen_mxfp +bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4 ``` ### Evaluation Tests From efd3b1dce2ec81a3fefc1d0468d3c71bf3a1a748 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:58:33 -0800 Subject: [PATCH 28/52] fix Signed-off-by: yiliu30 --- inc_examples/quantize.py | 45 +++++++++++++++++++++++++++++++++++- inc_examples/run_generate.sh | 32 ++++++++++++------------- 2 files changed, 59 
insertions(+), 18 deletions(-) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index bf0a4fcc3..0a650f371 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -32,7 +32,7 @@ } -def quant_model(args): +def quant_model_ar(args): config = topologies_config[args.t] logger.info(f"Using fp_layers: {config['fp_layers']}") @@ -49,7 +49,50 @@ def quant_model(args): format=format_type, output_dir=f"{args.output_dir}/quantized_model_{args.t}", ) + + +def get_model_and_tokenizer(model_name): + # Load model and tokenizer + fp32_model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True, + ) + return fp32_model, tokenizer + +def quant_model(args): + from neural_compressor.torch.quantization import ( + AutoRoundConfig, + convert, + prepare, + ) + config = topologies_config[args.t] + export_format = "auto_round" if args.use_autoround_format else "llm_compressor" + output_dir = f"{args.output_dir}/quantized_model_{args.t}" + fp32_model, tokenizer = get_model_and_tokenizer(args.model) + quant_config = AutoRoundConfig( + tokenizer=tokenizer, + # nsamples=32, + # seqlen=10, + # iters=1, + # amp=False, + # scale_dtype="fp16", + scheme=config["scheme"], + enable_torch_compile=args.enable_torch_compile, + iters=config['iters'], + fp_layers=config["fp_layers"], + export_format=export_format, + output_dir=output_dir, + ) + # quantizer execute + model = prepare(model=fp32_model, quant_config=quant_config) + inc_model = convert(model) + logger.info(f"Quantized model saved to {output_dir}") if __name__ == "__main__": import argparse diff --git a/inc_examples/run_generate.sh b/inc_examples/run_generate.sh index 2ca2dee30..a1e6cc93d 100644 --- a/inc_examples/run_generate.sh +++ b/inc_examples/run_generate.sh @@ -1,4 +1,3 @@ -#!/bin/bash # Model Testing Script # Usage: ./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] @@ -22,34 +21,33 @@ usage() { } # Parse command line arguments -while getopts "s:m:tp:h" opt; do - case $opt in - s) - QUANT_TYPE="$OPTARG" +while [[ $# -gt 0 ]]; do + case $1 in + -s) + QUANT_TYPE="$2" + shift 2 ;; - m) - MODEL_PATH="$OPTARG" + -m) + MODEL_PATH="$2" + shift 2 ;; - tp) - TP_SIZE="$OPTARG" + -tp) + TP_SIZE="$2" + shift 2 ;; - h) + -h) usage exit 0 ;; - \?) - echo "Invalid option: -$OPTARG" >&2 - usage - exit 1 - ;; - :) - echo "Option -$OPTARG requires an argument." 
>&2 + *) + echo "Invalid option: $1" >&2 usage exit 1 ;; esac done + # Validate quantization type QUANT_TYPE_UPPER=$(echo "$QUANT_TYPE" | tr '[:lower:]' '[:upper:]') if [[ "$QUANT_TYPE_UPPER" != "MXFP4" && "$QUANT_TYPE_UPPER" != "MXFP8" ]]; then From e5044b40375a876352b864b5d05e4bb6722089e1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 22:59:00 -0800 Subject: [PATCH 29/52] format Signed-off-by: yiliu30 --- inc_examples/quantize.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 0a650f371..2ae1c2a41 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -3,11 +3,11 @@ import transformers import logging from auto_round import AutoRound + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) - topologies_config = { "ds_mxfp8": { "scheme": "MXFP8", @@ -27,20 +27,20 @@ "qwen_mxfp4": { "scheme": "MXFP4", "fp_layers": "lm_head,mlp.gate,self_attn", - "iters": 0, # TODO: set to 200 before merge + "iters": 0, # TODO: set to 200 before merge }, } def quant_model_ar(args): config = topologies_config[args.t] - + logger.info(f"Using fp_layers: {config['fp_layers']}") autoround = AutoRound( model=args.model, scheme=config["scheme"], enable_torch_compile=args.enable_torch_compile, - iters=config['iters'], + iters=config["iters"], fp_layers=config["fp_layers"], ) logger.info(f"Save quantized model to {args.output_dir}") @@ -49,7 +49,7 @@ def quant_model_ar(args): format=format_type, output_dir=f"{args.output_dir}/quantized_model_{args.t}", ) - + def get_model_and_tokenizer(model_name): # Load model and tokenizer @@ -63,13 +63,15 @@ def get_model_and_tokenizer(model_name): trust_remote_code=True, ) return fp32_model, tokenizer - + + def quant_model(args): from neural_compressor.torch.quantization import ( AutoRoundConfig, convert, prepare, ) + config = topologies_config[args.t] export_format = "auto_round" if args.use_autoround_format else "llm_compressor" output_dir = f"{args.output_dir}/quantized_model_{args.t}" @@ -83,7 +85,7 @@ def quant_model(args): # scale_dtype="fp16", scheme=config["scheme"], enable_torch_compile=args.enable_torch_compile, - iters=config['iters'], + iters=config["iters"], fp_layers=config["fp_layers"], export_format=export_format, output_dir=output_dir, @@ -94,6 +96,7 @@ def quant_model(args): inc_model = convert(model) logger.info(f"Quantized model saved to {output_dir}") + if __name__ == "__main__": import argparse From 68424c52030a4aa63f13fd0e1c66d0cc9c38cca0 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 11 Nov 2025 23:00:44 -0800 Subject: [PATCH 30/52] update Signed-off-by: yiliu30 --- inc_examples/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index ed58d539e..5a0adf8eb 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -15,14 +15,14 @@ export DS_MODEL=deepseek-ai/DeepSeek-R1 - MXFP8 ```bash -python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format -python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format +python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format --output_dir ./qmodels +python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format ----output_dir ./qmodels ``` - MXFP4 ```bash -python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format -python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format +python quantize.py 
--model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels +python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels ``` From 55a4e5250aaab3093a66aa5d1d1ca3d8d32bf8e8 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 12 Nov 2025 03:09:10 -0800 Subject: [PATCH 31/52] update example Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/mxfp4_qdq_utils.py | 2 +- inc_examples/run_eval.sh | 8 ++++---- inc_examples/run_gen.sh | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index 214c98c8d..300cbf872 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -157,8 +157,8 @@ def dequant_mxfp4_to_fp8(data_lp, scale_e8m0): ) return data_fp8, scale_float - def mxfp4_fp8_weight_to_bf16(weight_fp8, scale_bf16): + origin_shape = weight_fp8.shape weight_fp8 = weight_fp8.reshape(-1, 32) scale_bf16 = scale_bf16.reshape(-1, 1) diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh index bc9d9d79b..73fff9171 100644 --- a/inc_examples/run_eval.sh +++ b/inc_examples/run_eval.sh @@ -6,7 +6,7 @@ if [ -z "$1" ]; then model_path="/storage/yiliu7/quantized_model_ds_mxfp4" model_path="/storage/yiliu7/quantized_model_ds_mxfp4" model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" - model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" + # model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" else model_path="$1" fi @@ -15,6 +15,7 @@ tp_size=4 model_name=$(basename ${model_path}) output_dir="${model_name}-tp${tp_size}-gsm8k-acc" task_name="gsm8k" +# task_name="mmlu" echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${output_dir}" # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ @@ -88,7 +89,7 @@ mkdir -p ${output_dir} VLLM_ENABLE_AR_EXT=1 \ VLLM_AR_MXFP4_MODULAR_MOE=1 \ VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ VLLM_ENABLE_STATIC_MOE=0 \ VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ VLLM_USE_DEEP_GEMM=0 \ @@ -96,8 +97,7 @@ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ lm_eval --model vllm \ --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ --tasks $task_name \ - --batch_size 16 \ - --limit 256 \ + --batch_size 512 \ --log_samples \ --seed 42 \ --output_path ${output_dir} \ diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh index c8323a78e..8562684cc 100644 --- a/inc_examples/run_gen.sh +++ b/inc_examples/run_gen.sh @@ -14,7 +14,7 @@ model_path="quantized_model_qwen_mxfp8" # model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" model_path="/storage/yiliu7/quantized_model_ds_mxfp8" model_path="/storage/yiliu7/quantized_model_ds_mxfp4" -model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" +# model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" tp_size=4 # /home/yiliu7/workspace/torchutils/examples @@ -59,16 +59,16 @@ tp_size=4 VLLM_AR_MXFP4_MODULAR_MOE=1 \ VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ VLLM_ENABLE_STATIC_MOE=0 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=1 \ VLLM_USE_DEEP_GEMM=0 \ VLLM_ENABLE_V1_MULTIPROCESSING=0 \ python generate.py \ --model ${model_path} \ --tensor_parallel_size $tp_size \ 
--max-tokens 16 \ - --max-num-seqs 2 \ + --max-num-seqs 32 \ --gpu_memory_utilization 0.75 \ --no-enable-prefix-caching \ --enable_expert_parallel From 59ff18def4dc2f013ef0951187fe8b57ac0c4756 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 12 Nov 2025 23:02:02 -0800 Subject: [PATCH 32/52] correct mxfp8 usage Signed-off-by: yiliu30 --- .../vllm_ext/moe_impl_mxfp4.py | 2 +- inc_examples/README.md | 3 +- inc_examples/quantize.py | 4 +- inc_examples/run_eval.sh | 38 +++++++++++-------- inc_examples/run_generate.sh | 2 +- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index c16844e6b..a5812805d 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -234,7 +234,7 @@ def _dequant_fp4_to_fp8(self, layer): ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - logger.info(f"Processing weights after loading for layer: {layer._prefix}") + logger.debug(f"Processing weights after loading for layer: {layer._prefix}") if envs.VLLM_ENABLE_STATIC_MOE: if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: self._dequant_fp4_to_fp8(layer) diff --git a/inc_examples/README.md b/inc_examples/README.md index 5a0adf8eb..5ad366e3e 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -10,6 +10,7 @@ - Export model path ```bash export QWEN_MODEL=Qwen/Qwen3-235B-A22B +export QWEN_MODEL=/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/ export DS_MODEL=deepseek-ai/DeepSeek-R1 ``` @@ -18,7 +19,7 @@ export DS_MODEL=deepseek-ai/DeepSeek-R1 python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format --output_dir ./qmodels python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format ----output_dir ./qmodels ``` - +/storage/yiliu7/meta-llama/Meta-Llama-3-8B-Instruct - MXFP4 ```bash python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py index 2ae1c2a41..716dfa464 100644 --- a/inc_examples/quantize.py +++ b/inc_examples/quantize.py @@ -32,7 +32,7 @@ } -def quant_model_ar(args): +def quant_model(args): config = topologies_config[args.t] logger.info(f"Using fp_layers: {config['fp_layers']}") @@ -65,7 +65,7 @@ def get_model_and_tokenizer(model_name): return fp32_model, tokenizer -def quant_model(args): +def quant_model_(args): from neural_compressor.torch.quantization import ( AutoRoundConfig, convert, diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh index 73fff9171..e61170cf4 100644 --- a/inc_examples/run_eval.sh +++ b/inc_examples/run_eval.sh @@ -1,3 +1,5 @@ + + #!/bin/bash # Check if a model name is passed as an argument, otherwise use the default model path if [ -z "$1" ]; then @@ -5,8 +7,11 @@ if [ -z "$1" ]; then # model_path="/storage/yiliu7/quantized_model_ds_mxfp8" model_path="/storage/yiliu7/quantized_model_ds_mxfp4" model_path="/storage/yiliu7/quantized_model_ds_mxfp4" - model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" - # model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" + model_path="/storage/yiliu7/quantized_model_ds_mxfp8" + # model_path="qmodels/quantized_model_ds_mxfp8" + model_path="./small-qmodels/quantized_model_qwen_mxfp8/" + # model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" + model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" else model_path="$1" fi @@ -14,8 +19,11 @@ fi tp_size=4 model_name=$(basename ${model_path}) 
output_dir="${model_name}-tp${tp_size}-gsm8k-acc" -task_name="gsm8k" -# task_name="mmlu" +# task_name="gsm8k" +# batch_size=256 +batch_size=512 +task_name="piqa,hellaswag,mmlu" +# task_name="mmlu_high_school_biology" echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${output_dir}" # VLLM_ATTENTION_BACKEND=TRITON_ATTN \ @@ -65,15 +73,14 @@ mkdir -p ${output_dir} # --output_path ${output_dir} \ # --show_config 2>&1 | tee ${output_dir}/log.txt - +# -MXFP4 Evaluation # /storage/yiliu7/quantized_model_qwen_mxfp4 4x200 -# VLLM_ENABLE_AR_EXT=1 \ # VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +# VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ # VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=1 \ +# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ # VLLM_USE_DEEP_GEMM=0 \ +# VLLM_ENABLE_AR_EXT=1 \ # VLLM_ENABLE_V1_MULTIPROCESSING=1 \ # lm_eval --model vllm \ # --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ @@ -85,19 +92,20 @@ mkdir -p ${output_dir} # --output_path ${output_dir} \ # --show_config 2>&1 | tee ${output_dir}/log.txt +# -MXFP8 Evaluation +# !!! Please set below knobs strictly for MXFP8 model evaluation !!! # /storage/yiliu7/quantized_model_qwen_mxfp8 4x200 VLLM_ENABLE_AR_EXT=1 \ -VLLM_AR_MXFP4_MODULAR_MOE=1 \ -VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ -VLLM_ENABLE_STATIC_MOE=0 \ +VLLM_AR_MXFP4_MODULAR_MOE=0 \ VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ +VLLM_ENABLE_STATIC_MOE=0 \ VLLM_USE_DEEP_GEMM=0 \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \ lm_eval --model vllm \ - --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ + --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ --tasks $task_name \ - --batch_size 512 \ + --batch_size $batch_size \ --log_samples \ --seed 42 \ --output_path ${output_dir} \ diff --git a/inc_examples/run_generate.sh b/inc_examples/run_generate.sh index a1e6cc93d..a3b869a69 100644 --- a/inc_examples/run_generate.sh +++ b/inc_examples/run_generate.sh @@ -95,7 +95,7 @@ export VLLM_ENABLE_AR_EXT=1 export VLLM_ENABLE_STATIC_MOE=0 export VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 export VLLM_USE_DEEP_GEMM=0 -export VLLM_ENABLE_V1_MULTIPROCESSING=1 +export VLLM_ENABLE_V1_MULTIPROCESSING=0 echo "Environment variables set:" echo " VLLM_AR_MXFP4_MODULAR_MOE=$VLLM_AR_MXFP4_MODULAR_MOE" From b8961e122bd0d877b944283f6334da359af503d1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Nov 2025 16:21:57 -0800 Subject: [PATCH 33/52] update example Signed-off-by: yiliu30 --- inc_examples/run_eval.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh index e61170cf4..4bcc101bc 100644 --- a/inc_examples/run_eval.sh +++ b/inc_examples/run_eval.sh @@ -9,14 +9,14 @@ if [ -z "$1" ]; then model_path="/storage/yiliu7/quantized_model_ds_mxfp4" model_path="/storage/yiliu7/quantized_model_ds_mxfp8" # 
model_path="qmodels/quantized_model_ds_mxfp8" - model_path="./small-qmodels/quantized_model_qwen_mxfp8/" + # model_path="./small-qmodels/quantized_model_qwen_mxfp8/" # model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" - model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" + # model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" else model_path="$1" fi -tp_size=4 +tp_size=8 model_name=$(basename ${model_path}) output_dir="${model_name}-tp${tp_size}-gsm8k-acc" # task_name="gsm8k" From 3d89bb3d92d4ec7145c2fd00bfe9b27a730e255c Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Nov 2025 17:28:14 -0800 Subject: [PATCH 34/52] clean Signed-off-by: yiliu30 --- inc_examples/README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/inc_examples/README.md b/inc_examples/README.md index 5ad366e3e..8f297f000 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -10,7 +10,6 @@ - Export model path ```bash export QWEN_MODEL=Qwen/Qwen3-235B-A22B -export QWEN_MODEL=/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/ export DS_MODEL=deepseek-ai/DeepSeek-R1 ``` @@ -19,7 +18,7 @@ export DS_MODEL=deepseek-ai/DeepSeek-R1 python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format --output_dir ./qmodels python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format ----output_dir ./qmodels ``` -/storage/yiliu7/meta-llama/Meta-Llama-3-8B-Instruct + - MXFP4 ```bash python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels @@ -42,7 +41,7 @@ bash ./run_generate.sh -s mxfp8 -tp 8 -m /path/to/ds_mxfp8 - MXFP4 ```bash bash ./run_generate.sh -s mxfp4 -tp 4 -m /path/to/qwen_mxfp -bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4 +bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4 ``` ### Evaluation Tests From 6acd7ead4ecbff75d05b6cdfaf10bdd41f97a2d0 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Nov 2025 21:46:35 -0800 Subject: [PATCH 35/52] add eval cmd Signed-off-by: yiliu30 --- inc_examples/README.md | 23 ++++++- inc_examples/run_eval.sh | 42 ------------ inc_examples/run_evaluation.sh | 117 +++++++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 44 deletions(-) create mode 100644 inc_examples/run_evaluation.sh diff --git a/inc_examples/README.md b/inc_examples/README.md index 8f297f000..715c2ad79 100644 --- a/inc_examples/README.md +++ b/inc_examples/README.md @@ -43,9 +43,28 @@ bash ./run_generate.sh -s mxfp8 -tp 8 -m /path/to/ds_mxfp8 bash ./run_generate.sh -s mxfp4 -tp 4 -m /path/to/qwen_mxfp bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4 ``` -### Evaluation Tests +### Evaluation + + +Usage: +```bash +bash run_evaluation.sh -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size] +``` +```bash +bash run_evaluation.sh -s mxfp8 -t piqa,hellaswag,mmlu -tp 4 -b 512 -m /path/to/qwen_mxfp8 +bash run_evaluation.sh -s mxfp8 -t gsm8k -tp 4 -b 256 -m /path/to/qwen_mxfp8 +bash run_evaluation.sh -s mxfp8 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp8 +bash run_evaluation.sh -s mxfp8 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp8 + +``` +- MXFP4 +```bash +bash run_evaluation.sh -s mxfp4 -t piqa,hellaswag,mmlu -tp 4 -b 512 -m /path/to/qwen_mxfp4 +bash run_evaluation.sh -s mxfp4 -t gsm8k -tp 4 -b 256 -m /path/to/qwen_mxfp4 +bash run_evaluation.sh -s mxfp4 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp4 +bash run_evaluation.sh -s mxfp4 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp4 +``` -WIP diff --git a/inc_examples/run_eval.sh 
b/inc_examples/run_eval.sh index 4bcc101bc..a8fea25d2 100644 --- a/inc_examples/run_eval.sh +++ b/inc_examples/run_eval.sh @@ -30,49 +30,7 @@ echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${outpu mkdir -p ${output_dir} # VLLM_ATTENTION_BACKEND=FLASHINFER \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_LOGGING_LEVEL=DEBUG \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_LOGGING_LEVEL=DEBUG \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# lm_eval --model vllm \ -# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ -# --tasks $task_name \ -# --batch_size 16 \ -# --limit 32 \ -# --log_samples \ -# --seed 42 \ -# --output_path ${output_dir} \ -# --show_config 2>&1 | tee ${output_dir}/log.txt - -# - -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ -# VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# lm_eval --model vllm \ -# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ -# --tasks $task_name \ -# --batch_size 16 \ -# --limit 256 \ -# --log_samples \ -# --seed 42 \ -# --output_path ${output_dir} \ -# --show_config 2>&1 | tee ${output_dir}/log.txt - # -MXFP4 Evaluation # /storage/yiliu7/quantized_model_qwen_mxfp4 4x200 # VLLM_AR_MXFP4_MODULAR_MOE=1 \ diff --git a/inc_examples/run_evaluation.sh b/inc_examples/run_evaluation.sh new file mode 100644 index 000000000..b0cc3270d --- /dev/null +++ b/inc_examples/run_evaluation.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Usage: ./run_evaluation.sh -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size] +# Default values +MODEL_PATH="" +SCHEME="mxfp8" +TASK_NAME="piqa,hellaswag,mmlu" +TP_SIZE=8 +BATCH_SIZE=512 + +# Function to display usage +usage() { + echo "Usage: $0 -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size]" + echo " -m: Path to the quantized model (required)" + echo " -s: Quantization scheme (mxfp4 or mxfp8, default: mxfp8)" + echo " -t: Task name(s) to evaluate (default: piqa,hellaswag,mmlu)" + echo " -tp: Tensor parallelism size (default: 8)" + echo " -b: Batch size (default: 512)" + echo "" + echo "Examples:" + echo " $0 -m /path/to/model -s mxfp4 -t gsm8k -tp 4 -b 256" + echo " $0 -m /path/to/model -s mxfp8 -t piqa,hellaswag -tp 8 -b 512" +} + +# Parse command-line arguments +while [[ $# -gt 0 ]]; do + case $1 in + -m) + MODEL_PATH="$2" + shift 2 + ;; + -s) + SCHEME="$2" + shift 2 + ;; + -t) + TASK_NAME="$2" + shift 2 + ;; + -tp) + TP_SIZE="$2" + shift 2 + ;; + -b) + BATCH_SIZE="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Invalid option: $1" >&2 + usage + exit 1 + ;; + esac +done + +# Validate required arguments +if [[ -z "$MODEL_PATH" ]]; then + echo "Error: Model path (-m) is required." 
+ usage + exit 1 +fi + +# Extract model name and set output directory +MODEL_NAME=$(basename ${MODEL_PATH}) +OUTPUT_DIR="${MODEL_NAME}-tp${TP_SIZE}-eval" + +# Create output directory +mkdir -p ${OUTPUT_DIR} + +# Set environment variables based on the quantization scheme +if [[ "$SCHEME" == "mxfp4" ]]; then + VLLM_AR_MXFP4_MODULAR_MOE=1 + VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 + VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 + VLLM_ENABLE_STATIC_MOE=0 + VLLM_USE_DEEP_GEMM=0 + VLLM_ENABLE_AR_EXT=1 +elif [[ "$SCHEME" == "mxfp8" ]]; then + VLLM_AR_MXFP4_MODULAR_MOE=0 + VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 + VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 + VLLM_ENABLE_STATIC_MOE=0 + VLLM_USE_DEEP_GEMM=0 + VLLM_ENABLE_AR_EXT=1 +else + echo "Error: Invalid quantization scheme (-s). Must be 'mxfp4' or 'mxfp8'." + usage + exit 1 +fi + +# Run evaluation +echo "Evaluating model: ${MODEL_PATH}" +echo "Quantization scheme: ${SCHEME}" +echo "Tasks: ${TASK_NAME}" +echo "Tensor parallelism size: ${TP_SIZE}" +echo "Batch size: ${BATCH_SIZE}" +echo "Output directory: ${OUTPUT_DIR}" + +VLLM_ENABLE_AR_EXT=$VLLM_ENABLE_AR_EXT \ +VLLM_AR_MXFP4_MODULAR_MOE=$VLLM_AR_MXFP4_MODULAR_MOE \ +VLLM_MXFP4_PRE_UNPACK_TO_FP8=$VLLM_MXFP4_PRE_UNPACK_TO_FP8 \ +VLLM_MXFP4_PRE_UNPACK_WEIGHTS=$VLLM_MXFP4_PRE_UNPACK_WEIGHTS \ +VLLM_ENABLE_STATIC_MOE=$VLLM_ENABLE_STATIC_MOE \ +VLLM_USE_DEEP_GEMM=$VLLM_USE_DEEP_GEMM \ +VLLM_ENABLE_V1_MULTIPROCESSING=1 \ +lm_eval --model vllm \ + --model_args "pretrained=${MODEL_PATH},tensor_parallel_size=${TP_SIZE},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ + --tasks $TASK_NAME \ + --batch_size $BATCH_SIZE \ + --log_samples \ + --seed 42 \ + --output_path ${OUTPUT_DIR} \ + --show_config 2>&1 | tee ${OUTPUT_DIR}/log.txt \ No newline at end of file From 6b720f062bd62f5d800259b5dad5accb08598520 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 13 Nov 2025 22:13:59 -0800 Subject: [PATCH 36/52] remove examples Signed-off-by: yiliu30 --- inc_examples/README.md | 70 ---------------- inc_examples/generate.py | 73 ---------------- inc_examples/quantize.py | 149 --------------------------------- inc_examples/run_eval.sh | 71 ---------------- inc_examples/run_evaluation.sh | 117 -------------------------- inc_examples/run_gen.sh | 80 ------------------ inc_examples/run_generate.sh | 115 ------------------------- inc_examples/run_quant.sh | 36 -------- 8 files changed, 711 deletions(-) delete mode 100644 inc_examples/README.md delete mode 100644 inc_examples/generate.py delete mode 100644 inc_examples/quantize.py delete mode 100644 inc_examples/run_eval.sh delete mode 100644 inc_examples/run_evaluation.sh delete mode 100644 inc_examples/run_gen.sh delete mode 100644 inc_examples/run_generate.sh delete mode 100644 inc_examples/run_quant.sh diff --git a/inc_examples/README.md b/inc_examples/README.md deleted file mode 100644 index 715c2ad79..000000000 --- a/inc_examples/README.md +++ /dev/null @@ -1,70 +0,0 @@ - -## Support Matrix - -| Model Family | MXFP4 | MXFP8 | -| ----------------------- | ----- | ----- | -| Qwen/Qwen3-235B-A22B | ✅ | ✅ | -| deepseek-ai/DeepSeek-R1 | ✅ | ✅ | - -### Quantize Model -- Export model path -```bash -export QWEN_MODEL=Qwen/Qwen3-235B-A22B -export DS_MODEL=deepseek-ai/DeepSeek-R1 -``` - -- MXFP8 -```bash -python quantize.py --model $QWEN_MODEL -t qwen_mxfp8 --use_autoround_format --output_dir ./qmodels -python quantize.py --model $DS_MODEL -t ds_mxfp8 --use_autoround_format ----output_dir 
./qmodels -``` - -- MXFP4 -```bash -python quantize.py --model $QWEN_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels -python quantize.py --model $DS_MODEL -t qwen_mxfp4 --use_autoround_format --output_dir ./qmodels -``` - - -### Prompt Tests - -Usage: -```bash -bash ./run_generate.sh -s [mxfp4|mxfp8] -tp [tensor_parallel_size] -m [model_path] -``` - -- MXFP8 -```bash -bash ./run_generate.sh -s mxfp8 -tp 4 -m /path/to/qwen_mxfp8 -bash ./run_generate.sh -s mxfp8 -tp 8 -m /path/to/ds_mxfp8 -``` -- MXFP4 -```bash -bash ./run_generate.sh -s mxfp4 -tp 4 -m /path/to/qwen_mxfp -bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4 -``` -### Evaluation - - -Usage: -```bash -bash run_evaluation.sh -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size] -``` -```bash -bash run_evaluation.sh -s mxfp8 -t piqa,hellaswag,mmlu -tp 4 -b 512 -m /path/to/qwen_mxfp8 -bash run_evaluation.sh -s mxfp8 -t gsm8k -tp 4 -b 256 -m /path/to/qwen_mxfp8 -bash run_evaluation.sh -s mxfp8 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp8 -bash run_evaluation.sh -s mxfp8 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp8 - -``` -- MXFP4 -```bash -bash run_evaluation.sh -s mxfp4 -t piqa,hellaswag,mmlu -tp 4 -b 512 -m /path/to/qwen_mxfp4 -bash run_evaluation.sh -s mxfp4 -t gsm8k -tp 4 -b 256 -m /path/to/qwen_mxfp4 -bash run_evaluation.sh -s mxfp4 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp4 -bash run_evaluation.sh -s mxfp4 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp4 -``` - - - - diff --git a/inc_examples/generate.py b/inc_examples/generate.py deleted file mode 100644 index 6a255be69..000000000 --- a/inc_examples/generate.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Copied from https://github.com/vllm-project/vllm/ - -try: - from auto_round_extension.vllm_ext import apply as apply_auto_round_extension - apply_auto_round_extension() -except ImportError: - print("auto_round_extension.vllm_ext not found, proceeding without auto-round extension.") - -from vllm import LLM, EngineArgs -from vllm.utils.argparse_utils import FlexibleArgumentParser - - - -def create_parser(): - parser = FlexibleArgumentParser() - # Add engine args - EngineArgs.add_cli_args(parser) - parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) - - return parser - - -def main(args: dict): - # Pop arguments not used by LLM - max_tokens = args.pop("max_tokens") - temperature = args.pop("temperature") - top_p = args.pop("top_p") - top_k = args.pop("top_k") - - # Create an LLM - llm = LLM(**args) - - # Create a sampling params object - sampling_params = llm.get_default_sampling_params() - if max_tokens is not None: - sampling_params.max_tokens = max_tokens - if temperature is not None: - sampling_params.temperature = temperature - if top_p is not None: - sampling_params.top_p = top_p - if top_k is not None: - sampling_params.top_k = top_k - - # Generate texts from the prompts. The output is a list of RequestOutput - # objects that contain the prompt, generated text, and other information. 
- prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - parser = create_parser() - args: dict = vars(parser.parse_args()) - main(args) diff --git a/inc_examples/quantize.py b/inc_examples/quantize.py deleted file mode 100644 index 716dfa464..000000000 --- a/inc_examples/quantize.py +++ /dev/null @@ -1,149 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import transformers -import logging -from auto_round import AutoRound - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -topologies_config = { - "ds_mxfp8": { - "scheme": "MXFP8", - "fp_layers": "lm_head", - "iters": 0, - }, - "ds_mxfp4": { - "scheme": "MXFP4", - "fp_layers": "lm_head,self_attn", - "iters": 0, - }, - "qwen_mxfp8": { - "scheme": "MXFP8", - "fp_layers": "lm_head,mlp.gate", - "iters": 0, - }, - "qwen_mxfp4": { - "scheme": "MXFP4", - "fp_layers": "lm_head,mlp.gate,self_attn", - "iters": 0, # TODO: set to 200 before merge - }, -} - - -def quant_model(args): - config = topologies_config[args.t] - - logger.info(f"Using fp_layers: {config['fp_layers']}") - autoround = AutoRound( - model=args.model, - scheme=config["scheme"], - enable_torch_compile=args.enable_torch_compile, - iters=config["iters"], - fp_layers=config["fp_layers"], - ) - logger.info(f"Save quantized model to {args.output_dir}") - format_type = "auto_round" if args.use_autoround_format else "llm_compressor" - autoround.quantize_and_save( - format=format_type, - output_dir=f"{args.output_dir}/quantized_model_{args.t}", - ) - - -def get_model_and_tokenizer(model_name): - # Load model and tokenizer - fp32_model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="cpu", - trust_remote_code=True, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_name, - trust_remote_code=True, - ) - return fp32_model, tokenizer - - -def quant_model_(args): - from neural_compressor.torch.quantization import ( - AutoRoundConfig, - convert, - prepare, - ) - - config = topologies_config[args.t] - export_format = "auto_round" if args.use_autoround_format else "llm_compressor" - output_dir = f"{args.output_dir}/quantized_model_{args.t}" - fp32_model, tokenizer = get_model_and_tokenizer(args.model) - quant_config = AutoRoundConfig( - tokenizer=tokenizer, - # nsamples=32, - # seqlen=10, - # iters=1, - # amp=False, - # scale_dtype="fp16", - scheme=config["scheme"], - enable_torch_compile=args.enable_torch_compile, - iters=config["iters"], - fp_layers=config["fp_layers"], - export_format=export_format, - output_dir=output_dir, - ) - - # quantizer execute - model = prepare(model=fp32_model, quant_config=quant_config) - inc_model = convert(model) - logger.info(f"Quantized model saved to {output_dir}") - - -if __name__ == "__main__": - import argparse - - # Parse command-line arguments - parser = argparse.ArgumentParser(description="Select a quantization scheme.") - parser.add_argument( - "--model", - type=str, - help="Path to the pre-trained model or model identifier from Hugging Face Hub.", - ) - parser.add_argument( - "-t", - type=str, - choices=topologies_config.keys(), - default="qwen_mxfp4", - help="Quantization scheme to 
use. Available options: " + ", ".join(topologies_config.keys()), - ) - - parser.add_argument( - "--enable_torch_compile", - action="store_true", - help="Enable torch compile for the model.", - ) - parser.add_argument( - "--use_autoround_format", - action="store_true", - help="Use AutoRound format for saving the quantized model.", - ) - - parser.add_argument( - "--skip_attn", - action="store_true", - help="Skip quantize attention layers.", - ) - parser.add_argument( - "--iters", - type=int, - default=0, - help="Number of iterations for quantization.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="./", - help="Directory to save the quantized model.", - ) - - args = parser.parse_args() - - quant_model(args) diff --git a/inc_examples/run_eval.sh b/inc_examples/run_eval.sh deleted file mode 100644 index a8fea25d2..000000000 --- a/inc_examples/run_eval.sh +++ /dev/null @@ -1,71 +0,0 @@ - - -#!/bin/bash -# Check if a model name is passed as an argument, otherwise use the default model path -if [ -z "$1" ]; then -# model_path="Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound" - # model_path="/storage/yiliu7/quantized_model_ds_mxfp8" - model_path="/storage/yiliu7/quantized_model_ds_mxfp4" - model_path="/storage/yiliu7/quantized_model_ds_mxfp4" - model_path="/storage/yiliu7/quantized_model_ds_mxfp8" - # model_path="qmodels/quantized_model_ds_mxfp8" - # model_path="./small-qmodels/quantized_model_qwen_mxfp8/" - # model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" - # model_path="/storage/yiliu7/quantized_model_qwen_mxfp8" -else - model_path="$1" -fi - -tp_size=8 -model_name=$(basename ${model_path}) -output_dir="${model_name}-tp${tp_size}-gsm8k-acc" -# task_name="gsm8k" -# batch_size=256 -batch_size=512 -task_name="piqa,hellaswag,mmlu" -# task_name="mmlu_high_school_biology" - -echo "Evaluating model: ${model_path} on task: ${task_name}, output dir: ${output_dir}" -# VLLM_ATTENTION_BACKEND=TRITON_ATTN \ -mkdir -p ${output_dir} -# VLLM_ATTENTION_BACKEND=FLASHINFER \ - - -# -MXFP4 Evaluation -# /storage/yiliu7/quantized_model_qwen_mxfp4 4x200 -# VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 \ -# VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# lm_eval --model vllm \ -# --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False,enable_expert_parallel=True" \ -# --tasks $task_name \ -# --batch_size 16 \ -# --limit 256 \ -# --log_samples \ -# --seed 42 \ -# --output_path ${output_dir} \ -# --show_config 2>&1 | tee ${output_dir}/log.txt - -# -MXFP8 Evaluation -# !!! Please set below knobs strictly for MXFP8 model evaluation !!! 
-# /storage/yiliu7/quantized_model_qwen_mxfp8 4x200 -VLLM_ENABLE_AR_EXT=1 \ -VLLM_AR_MXFP4_MODULAR_MOE=0 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ -VLLM_ENABLE_STATIC_MOE=0 \ -VLLM_USE_DEEP_GEMM=0 \ -VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -lm_eval --model vllm \ - --model_args "pretrained=${model_path},tensor_parallel_size=${tp_size},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ - --tasks $task_name \ - --batch_size $batch_size \ - --log_samples \ - --seed 42 \ - --output_path ${output_dir} \ - --show_config 2>&1 | tee ${output_dir}/log.txt - diff --git a/inc_examples/run_evaluation.sh b/inc_examples/run_evaluation.sh deleted file mode 100644 index b0cc3270d..000000000 --- a/inc_examples/run_evaluation.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash - -# Usage: ./run_evaluation.sh -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size] -# Default values -MODEL_PATH="" -SCHEME="mxfp8" -TASK_NAME="piqa,hellaswag,mmlu" -TP_SIZE=8 -BATCH_SIZE=512 - -# Function to display usage -usage() { - echo "Usage: $0 -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size]" - echo " -m: Path to the quantized model (required)" - echo " -s: Quantization scheme (mxfp4 or mxfp8, default: mxfp8)" - echo " -t: Task name(s) to evaluate (default: piqa,hellaswag,mmlu)" - echo " -tp: Tensor parallelism size (default: 8)" - echo " -b: Batch size (default: 512)" - echo "" - echo "Examples:" - echo " $0 -m /path/to/model -s mxfp4 -t gsm8k -tp 4 -b 256" - echo " $0 -m /path/to/model -s mxfp8 -t piqa,hellaswag -tp 8 -b 512" -} - -# Parse command-line arguments -while [[ $# -gt 0 ]]; do - case $1 in - -m) - MODEL_PATH="$2" - shift 2 - ;; - -s) - SCHEME="$2" - shift 2 - ;; - -t) - TASK_NAME="$2" - shift 2 - ;; - -tp) - TP_SIZE="$2" - shift 2 - ;; - -b) - BATCH_SIZE="$2" - shift 2 - ;; - -h|--help) - usage - exit 0 - ;; - *) - echo "Invalid option: $1" >&2 - usage - exit 1 - ;; - esac -done - -# Validate required arguments -if [[ -z "$MODEL_PATH" ]]; then - echo "Error: Model path (-m) is required." - usage - exit 1 -fi - -# Extract model name and set output directory -MODEL_NAME=$(basename ${MODEL_PATH}) -OUTPUT_DIR="${MODEL_NAME}-tp${TP_SIZE}-eval" - -# Create output directory -mkdir -p ${OUTPUT_DIR} - -# Set environment variables based on the quantization scheme -if [[ "$SCHEME" == "mxfp4" ]]; then - VLLM_AR_MXFP4_MODULAR_MOE=1 - VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 - VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 - VLLM_ENABLE_STATIC_MOE=0 - VLLM_USE_DEEP_GEMM=0 - VLLM_ENABLE_AR_EXT=1 -elif [[ "$SCHEME" == "mxfp8" ]]; then - VLLM_AR_MXFP4_MODULAR_MOE=0 - VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 - VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 - VLLM_ENABLE_STATIC_MOE=0 - VLLM_USE_DEEP_GEMM=0 - VLLM_ENABLE_AR_EXT=1 -else - echo "Error: Invalid quantization scheme (-s). Must be 'mxfp4' or 'mxfp8'." 
- usage - exit 1 -fi - -# Run evaluation -echo "Evaluating model: ${MODEL_PATH}" -echo "Quantization scheme: ${SCHEME}" -echo "Tasks: ${TASK_NAME}" -echo "Tensor parallelism size: ${TP_SIZE}" -echo "Batch size: ${BATCH_SIZE}" -echo "Output directory: ${OUTPUT_DIR}" - -VLLM_ENABLE_AR_EXT=$VLLM_ENABLE_AR_EXT \ -VLLM_AR_MXFP4_MODULAR_MOE=$VLLM_AR_MXFP4_MODULAR_MOE \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=$VLLM_MXFP4_PRE_UNPACK_TO_FP8 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=$VLLM_MXFP4_PRE_UNPACK_WEIGHTS \ -VLLM_ENABLE_STATIC_MOE=$VLLM_ENABLE_STATIC_MOE \ -VLLM_USE_DEEP_GEMM=$VLLM_USE_DEEP_GEMM \ -VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -lm_eval --model vllm \ - --model_args "pretrained=${MODEL_PATH},tensor_parallel_size=${TP_SIZE},max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False" \ - --tasks $TASK_NAME \ - --batch_size $BATCH_SIZE \ - --log_samples \ - --seed 42 \ - --output_path ${OUTPUT_DIR} \ - --show_config 2>&1 | tee ${OUTPUT_DIR}/log.txt \ No newline at end of file diff --git a/inc_examples/run_gen.sh b/inc_examples/run_gen.sh deleted file mode 100644 index 8562684cc..000000000 --- a/inc_examples/run_gen.sh +++ /dev/null @@ -1,80 +0,0 @@ -# export VLLM_LOGGING_LEVEL=DEBUG -# export VLLM_ENABLE_V1_MULTIPROCESSING=0 - - - -model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4/" -model_path="quantized_models/DeepSeek-V2-Lite-Chat-MXFP4" -model_path="quantized_model_qwen_mxfp8" -# model_path="quantized_model_ds_mxfp8" -# model_path="quantized_model_ds_mxfp4" -# model_path="quantized_model_qwen_mxfp4" -# model_path="quantized_model_qwen_mxfp4" -# model_path="quantized_models/Qwen3-235B-A22B-MXFP4" -# model_path="quantized_models/Qwen3-30B-A3B-Base-MXFP4" -model_path="/storage/yiliu7/quantized_model_ds_mxfp8" -model_path="/storage/yiliu7/quantized_model_ds_mxfp4" -# model_path="/storage/yiliu7/quantized_model_qwen_mxfp4" -tp_size=4 -# /home/yiliu7/workspace/torchutils/examples - -# VLLM_ATTENTION_BACKEND=TRITON_ATTN \ -# VLLM_LOGGING_LEVEL=DEBUG \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_AR_MXFP4_MODULAR_MOE=1 \ -# VLLM_ENABLE_STATIC_MOE=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# python generate.py \ -# --model ${model_path} \ -# --tensor_parallel_size 8 \ -# --max-tokens 16 \ -# --max-num-seqs 32 \ -# --gpu_memory_utilization 0.9 \ -# --distributed_executor_backend mp -# # --tensor_parallel_size 4 - - -# VLLM_LOGGING_LEVEL=DEBUG \ -# VLLM_ENABLE_AR_EXT=1 \ -# VLLM_AR_MXFP4_MODULAR_MOE=0 \ -# VLLM_ENABLE_STATIC_MOE=1 \ -# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \ -# VLLM_USE_DEEP_GEMM=0 \ -# VLLM_ENABLE_V1_MULTIPROCESSING=1 \ -# python generate.py \ -# --model ${model_path} \ -# --tensor_parallel_size 4 \ -# --max-tokens 16 \ -# --max-num-seqs 32 \ -# --gpu_memory_utilization 0.9 \ -# --distributed_executor_backend mp \ -# --enforce-eager -# # --tensor_parallel_size 4 - -# # --enforce-eager \ -# # --max-model-len 1024 \ -# VLLM_LOGGING_LEVEL=DEBUG \ -# model_path="/home/yiliu7/workspace/auto-round/inc_examples/quantized_model_ds_mxfp4" - -VLLM_AR_MXFP4_MODULAR_MOE=1 \ -VLLM_ENABLE_AR_EXT=1 \ -VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 \ -VLLM_ENABLE_STATIC_MOE=0 \ -VLLM_MXFP4_PRE_UNPACK_WEIGHTS=1 \ -VLLM_USE_DEEP_GEMM=0 \ -VLLM_ENABLE_V1_MULTIPROCESSING=0 \ - python generate.py \ - --model ${model_path} \ - --tensor_parallel_size $tp_size \ - --max-tokens 16 \ - --max-num-seqs 32 \ - --gpu_memory_utilization 0.75 \ - --no-enable-prefix-caching \ - --enable_expert_parallel - # \ - # 
--enforce-eager - # --tensor_parallel_size 4 - - # --enforce-eager \ - # --max-model-len 1024 \ \ No newline at end of file diff --git a/inc_examples/run_generate.sh b/inc_examples/run_generate.sh deleted file mode 100644 index a3b869a69..000000000 --- a/inc_examples/run_generate.sh +++ /dev/null @@ -1,115 +0,0 @@ - -# Model Testing Script -# Usage: ./run_generate.sh -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size] - -# Default values -QUANT_TYPE="mxfp8" -MODEL_PATH="/path/to/quantized_model" -TP_SIZE=8 - -# Function to display usage -usage() { - echo "Usage: $0 -s [mxfp4|mxfp8] -m [model_path] -tp [tensor_parallel_size]" - echo " -s: Quantization scheme (mxfp4 or mxfp8, default: mxfp8)" - echo " -m: Path to quantized model (required)" - echo " -tp: Tensor parallelism size (default: 8)" - echo "" - echo "Examples:" - echo " $0 -s mxfp4 -m /path/to/my/model -tp 4" - echo " $0 -m /path/to/my/model" - echo " $0 -s mxfp8 -m /path/to/my/model" -} - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case $1 in - -s) - QUANT_TYPE="$2" - shift 2 - ;; - -m) - MODEL_PATH="$2" - shift 2 - ;; - -tp) - TP_SIZE="$2" - shift 2 - ;; - -h) - usage - exit 0 - ;; - *) - echo "Invalid option: $1" >&2 - usage - exit 1 - ;; - esac -done - - -# Validate quantization type -QUANT_TYPE_UPPER=$(echo "$QUANT_TYPE" | tr '[:lower:]' '[:upper:]') -if [[ "$QUANT_TYPE_UPPER" != "MXFP4" && "$QUANT_TYPE_UPPER" != "MXFP8" ]]; then - echo "Error: Quantization type must be mxfp4 or mxfp8" - usage - exit 1 -fi - -# Validate model path -if [[ "$MODEL_PATH" == "/path/to/quantized_model" ]]; then - echo "Error: Model path is required (-m)" - usage - exit 1 -fi - -if [[ ! -d "$MODEL_PATH" ]]; then - echo "Error: Model path '$MODEL_PATH' does not exist or is not a directory" - exit 1 -fi - -# Validate TP_SIZE is a number -if ! [[ "$TP_SIZE" =~ ^[0-9]+$ ]] || [ "$TP_SIZE" -lt 1 ]; then - echo "Error: Tensor parallelism size must be a positive integer" - exit 1 -fi - -echo "Running $QUANT_TYPE_UPPER test with:" -echo " Model: $MODEL_PATH" -echo " Tensor Parallelism: $TP_SIZE" -echo "" - -# Set environment variables based on quantization type -if [[ "$QUANT_TYPE_UPPER" == "MXFP4" ]]; then - export VLLM_AR_MXFP4_MODULAR_MOE=1 - export VLLM_MXFP4_PRE_UNPACK_TO_FP8=1 - echo "Using MXFP4 configuration" -else - export VLLM_AR_MXFP4_MODULAR_MOE=0 - export VLLM_MXFP4_PRE_UNPACK_TO_FP8=0 - echo "Using MXFP8 configuration" -fi - -# Common environment variables -export VLLM_ENABLE_AR_EXT=1 -export VLLM_ENABLE_STATIC_MOE=0 -export VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 -export VLLM_USE_DEEP_GEMM=0 -export VLLM_ENABLE_V1_MULTIPROCESSING=0 - -echo "Environment variables set:" -echo " VLLM_AR_MXFP4_MODULAR_MOE=$VLLM_AR_MXFP4_MODULAR_MOE" -echo " VLLM_MXFP4_PRE_UNPACK_TO_FP8=$VLLM_MXFP4_PRE_UNPACK_TO_FP8" -echo " VLLM_ENABLE_AR_EXT=$VLLM_ENABLE_AR_EXT" -echo "" - -# Run the model -echo "Starting model generation..." 
-python generate.py \ - --model "${MODEL_PATH}" \ - --tensor_parallel_size $TP_SIZE \ - --max-tokens 16 \ - --max-num-seqs 4 \ - --gpu_memory_utilization 0.75 \ - --no-enable-prefix-caching \ - --enable_expert_parallel \ No newline at end of file diff --git a/inc_examples/run_quant.sh b/inc_examples/run_quant.sh deleted file mode 100644 index 83a7dae38..000000000 --- a/inc_examples/run_quant.sh +++ /dev/null @@ -1,36 +0,0 @@ - -export AR_LOG_LEVEL=TRACE -qwen_model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -# ds_model="/storage/yiliu7/Qwen/Qwen3-30B-A3B-Base/" -ds_model="/storage/yiliu7/deepseek-ai/DeepSeek-V2-Lite-Chat" -# ds_model="/storage/yiliu7/unsloth/DeepSeek-R1-BF16" -qwen_model="/storage/yiliu7/Qwen/Qwen3-235B-A22B" -base_name=$(basename ${model}) -scheme="MXFP4" -scheme="MXFP8" -qmodel_dir="quantized_models/" -mkdir -p ${qmodel_dir} -output_dir="${qmodel_dir}/${base_name}-${scheme}" -# python quantize.py --model $model --scheme $scheme --output_dir $output_dir --skip_attn --use_autoround_format -# python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format -python quantize.py --model $qwen_model -t qwen_mxfp4 --use_autoround_format -# python quantize.py --model $qwen_model -t qwen_mxfp8 --use_autoround_format -# python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format -# python quantize.py --model $ds_model -t ds_mxfp8 --use_autoround_format -# python quantize.py --model $ds_model -t ds_mxfp4 --use_autoround_format -# python quantize.py --model $model -t qwen_mxfp8 --use_autoround_format -# python quantize.py --model $model -t ds_mxfp8 --use_autoround_format -# model_name="/storage/yiliu7/Qwen/Qwen3-A3B-Base" - -# scheme="MXFP4" - -# output_path="./" -# base_name=$(basename ${model_name}) -# CUDA_VISIBLE_DEVICES=$device \ -# python3 quantize.py \ -# --model ${model} \ -# --scheme ${scheme} \ -# --format llm_compressor \ -# --iters 0 \ -# --enable_torch_compile \ -# --output_dir ${output_path}/${base_name}-${scheme} \ No newline at end of file From f880a1f3c44543f9d25d23a32ac886edf31fd645 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Nov 2025 06:15:52 +0000 Subject: [PATCH 37/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/__init__.py | 2 +- .../vllm_ext/linear_impl_mxfp4.py | 22 ++++++++++++++----- .../vllm_ext/moe_impl_mxfp4.py | 6 +++-- .../vllm_ext/mxfp4_qdq_utils.py | 3 ++- .../vllm_ext/quant_method_linear.py | 4 ++-- .../vllm_ext/quant_method_moe.py | 2 +- 6 files changed, 26 insertions(+), 13 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index bccfb4c4e..3e145334f 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -23,4 +23,4 @@ def apply(): from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig - from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables \ No newline at end of file + from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py index 45ead99f5..d90b73ea5 100644 --- a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py @@ -1,21 +1,32 @@ +# Copyright (c) 
2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # SPDX-License-Identifier: Apache-2.0 from typing import Callable, Optional import torch -from torch.nn.parameter import Parameter - import vllm.envs as envs +from torch.nn.parameter import Parameter from vllm.logger import init_logger - - from vllm.model_executor.parameter import GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter +from vllm.platforms import current_platform from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( dequant_mxfp4_to_fp8, mxfp4_gemm_with_unpacked_weight, run_mxfp4_emulations, ) -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -121,4 +132,3 @@ def apply_weights( bias=bias, ) return out - diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index a5812805d..7fcf44ba0 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -160,6 +160,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo from vllm.model_executor.layers.fused_moe.config import ( ocp_mx_moe_quant_config, ) + if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: self.input_dtype = "mxfp8_e4m3" self.weight_dtype = "mxfp8_e4m3" @@ -174,7 +175,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo w2_bias=layer.w2_bias if self.has_bias else None, block_shape=None, ) - + self.input_dtype = "mxfp4" self.weight_dtype = "mxfp4" return ocp_mx_moe_quant_config( @@ -243,7 +244,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: self._dequant_fp4_to_fp8(layer) return - + def revert_interleaved_bias(bias): """ Convert from blocked bias format to interleaved format. 
@@ -294,6 +295,7 @@ def revert_interleaved_w1(w1): new_w1[:, ::2, :] = w1[:, : N // 2, :] new_w1[:, 1::2, :] = w1[:, N // 2 :, :] return new_w1 + if envs.VLLM_AR_POST_PROCESS_GPTOSS: w1 = revert_interleaved_w1(w1) diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index 300cbf872..1916fbc62 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -157,8 +157,9 @@ def dequant_mxfp4_to_fp8(data_lp, scale_e8m0): ) return data_fp8, scale_float + def mxfp4_fp8_weight_to_bf16(weight_fp8, scale_bf16): - + origin_shape = weight_fp8.shape weight_fp8 = weight_fp8.reshape(-1, 32) scale_bf16 = scale_bf16.reshape(-1, 1) diff --git a/auto_round_extension/vllm_ext/quant_method_linear.py b/auto_round_extension/vllm_ext/quant_method_linear.py index 0bc19c4f4..eecc97902 100644 --- a/auto_round_extension/vllm_ext/quant_method_linear.py +++ b/auto_round_extension/vllm_ext/quant_method_linear.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, need_quantize, get_scheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, get_scheme, need_quantize logger = init_logger(__name__) @@ -60,7 +60,7 @@ def get_impl(scheme: QuantizationScheme): if prefix.endswith(packed_name): prefix = prefix.replace(packed_name, child_names[0]) break - + # TODO: use a more robust way to map layer names if prefix.endswith("gate_up_proj"): # update gate_up_proj to gate_proj, assume both gate and up share the same quantization scheme diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index d4da8f1d2..ef5c75291 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, need_quantize, get_scheme +from auto_round_extension.vllm_ext.utils import _is_mxfp4_w4a4, _is_mxfp8_w8a8, get_scheme, need_quantize logger = init_logger(__name__) From ed856fc0a7f20e3560707111e39f17230c536071 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 16 Nov 2025 22:59:23 -0800 Subject: [PATCH 38/52] fix mxfp4 Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/moe_impl_mxfp4.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index a5812805d..46d42097c 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -163,7 +163,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: self.input_dtype = "mxfp8_e4m3" self.weight_dtype = "mxfp8_e4m3" - return ocp_mx_moe_quant_config( + ocp_config = ocp_mx_moe_quant_config( quant_dtype=self.input_dtype, weight_dtype=self.weight_dtype, w1_scale=layer.w13_weight_scale, @@ -174,6 +174,17 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo w2_bias=layer.w2_bias if self.has_bias else None, block_shape=None, ) + from 
vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( + OCP_MX_DTYPES, + OCP_MX_Scheme, + ) + ocp_config.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( + input_dtype="mxfp8_e4m3", + weight_dtype="mxfp8_e4m3", + original_dtype="mxfp4", + ) + logger.warning_once(f"Set OCP_MX_Scheme to {ocp_config.ocp_mx_scheme}") + return ocp_config self.input_dtype = "mxfp4" self.weight_dtype = "mxfp4" From 70ce2d092d1d2cbe1997a077c5c4bc0f222e2d74 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 16 Nov 2025 23:03:16 -0800 Subject: [PATCH 39/52] add readme Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 auto_round_extension/vllm_ext/README.md diff --git a/auto_round_extension/vllm_ext/README.md b/auto_round_extension/vllm_ext/README.md new file mode 100644 index 000000000..47fde5e81 --- /dev/null +++ b/auto_round_extension/vllm_ext/README.md @@ -0,0 +1,14 @@ +- vllm https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar +``` +VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv +``` +- Allow python patches vLLM with vLLM-Ext +``` +cd auto-round/auto_round_extension/vllm_ext +source apply_ext.sh +``` + +- Enable vLLM-Ext +```bash +VLLM_ENABLE_AR_EXT=1 vllm serve ... +``` \ No newline at end of file From 919e9540c307385fbf1e266d9ba93a97f5f43d80 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 16 Nov 2025 23:05:04 -0800 Subject: [PATCH 40/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/auto_round_extension/vllm_ext/README.md b/auto_round_extension/vllm_ext/README.md index 47fde5e81..790d7796b 100644 --- a/auto_round_extension/vllm_ext/README.md +++ b/auto_round_extension/vllm_ext/README.md @@ -1,14 +1,16 @@ -- vllm https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar +- Build and Install vLLM + ``` +https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv ``` -- Allow python patches vLLM with vLLM-Ext +- Apply vLLM-Ext Patches(allow python recognize them) ``` cd auto-round/auto_round_extension/vllm_ext source apply_ext.sh ``` -- Enable vLLM-Ext +- Enable vLLM-Ext at Runtime ```bash VLLM_ENABLE_AR_EXT=1 vllm serve ... 
``` \ No newline at end of file From 3af54ccfe1fd68568ee5a395b40867901b79b6de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Nov 2025 07:06:30 +0000 Subject: [PATCH 41/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/moe_impl_mxfp4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index f6ee409a1..0ff5241ab 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -179,6 +179,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo OCP_MX_DTYPES, OCP_MX_Scheme, ) + ocp_config.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( input_dtype="mxfp8_e4m3", weight_dtype="mxfp8_e4m3", @@ -186,7 +187,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo ) logger.warning_once(f"Set OCP_MX_Scheme to {ocp_config.ocp_mx_scheme}") return ocp_config - + self.input_dtype = "mxfp4" self.weight_dtype = "mxfp4" return ocp_mx_moe_quant_config( From a29fd0a9cb8ff740831ff770f056eaed49842fcf Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 19 Nov 2025 22:04:07 -0800 Subject: [PATCH 42/52] add moe mxfp8 Signed-off-by: yiliu30 --- .../vllm_ext/moe_impl_mxfp8.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 auto_round_extension/vllm_ext/moe_impl_mxfp8.py diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp8.py b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py new file mode 100644 index 000000000..bf0c0a579 --- /dev/null +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
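# MXFP8 fused-MoE method: FP8 (e4m3) expert weights with uint8 E8M0 scales per
# 32-element group, dispatched through vLLM's fused_experts via an OCP MX quant config.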
+ +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +import vllm.envs as envs +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.utils import set_weight_attrs + +import auto_round_extension.vllm_ext.mxfp4_qdq_utils as mxfp4_utils +from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( + dequant_mxfp4_to_fp8, + mxfp4_gemm_with_unpacked_weight, + run_mxfp4_emulations, +) +from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod + +logger = init_logger(__name__) + + +def apply_act(local_w1_out: torch.Tensor, local_w3_out: torch.Tensor, activation: str) -> torch.Tensor: + if activation == "silu": + act_fn = F.silu + w13_out = act_fn(local_w1_out) * local_w3_out + elif activation == "swigluoai": + limit = 7.0 + alpha = 1.702 + local_w1_out = local_w1_out.clamp(min=None, max=limit) + local_w3_out = local_w3_out.clamp(min=-limit, max=limit) + glu = (local_w1_out) * F.sigmoid(local_w1_out * alpha) + w13_out = (local_w3_out + 1) * glu + else: + raise NotImplementedError(f"Activation {activation} is not implemented.") + return w13_out + + +class AutoRoundMoEMethodMXFp8Impl(AutoRoundMoEMethod): + def __init__( + self, + quant_config: "AutoRoundConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.use_marlin = False + self.group_size = 32 + self.quant_config = quant_config + self.has_bias = self.moe.has_bias + + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.uint8, # E8M0 for MXFP8 scale + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + # w2 + w2_weight_scale = torch.nn.Parameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.uint8, # E8M0 for MXFP8 scale + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSORGROUP quantization for FusedMoE.weight_loader. 
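        # FusedMoeWeightScaleSupported.GROUP tells FusedMoE's weight_loader to expect the
        # group-wise scale layout created above: [num_experts, out_dim, in_dim // 32].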
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + from vllm.model_executor.layers.fused_moe.config import ( + ocp_mx_moe_quant_config, + ) + + self.input_dtype = "mxfp8_e4m3" + self.weight_dtype = "mxfp8_e4m3" + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + block_shape=None, + ) + return None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + return + + @torch.inference_mode() + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + topk_weights, topk_ids, _ = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + assert self.fused_experts is None + + from vllm.model_executor.layers.fused_moe import fused_experts + + w1 = layer.w13_weight + w2 = layer.w2_weight + out = fused_experts( + x, + w1, + w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) + return out From 2d37e637dea98c0f0419014b621418343ad3b83c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 06:05:04 +0000 Subject: [PATCH 43/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/moe_impl_mxfp8.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp8.py b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py index bf0c0a579..65a62337c 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp8.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py @@ -66,7 +66,6 @@ def __init__( self.quant_config = quant_config self.has_bias = self.moe.has_bias - def create_weights( self, layer: torch.nn.Module, @@ -132,14 +131,10 @@ def create_weights( ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # 
Add PER-TENSORGROUP quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: from vllm.model_executor.layers.fused_moe.config import ( ocp_mx_moe_quant_config, @@ -161,7 +156,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMo return None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return + return @torch.inference_mode() def apply( From 9841c35d8f61fe75b2eda1d1cc4a2780cf5e8270 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 24 Nov 2025 21:56:46 -0800 Subject: [PATCH 44/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/linear_impl_mxfp4.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py index d90b73ea5..04d5e20f8 100644 --- a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py @@ -87,11 +87,6 @@ def create_weights( def process_weights_after_loading(self, layer) -> None: # FIXME: may dequant to bf16 if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( - dequant_mxfp4_to_fp8, - mxfp4_gemm_with_unpacked_weight, - run_mxfp4_emulations, - ) weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( data_lp=layer.weight_packed, From 52a02a85a1c17280414980c25b85b98a7d63a344 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 24 Nov 2025 21:59:13 -0800 Subject: [PATCH 45/52] clean Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/moe_impl_mxfp8.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp8.py b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py index bf0c0a579..a5d6965ff 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp8.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp8.py @@ -16,8 +16,6 @@ import torch import torch.nn.functional as F -import vllm.envs as envs -from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -27,12 +25,6 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.utils import set_weight_attrs -import auto_round_extension.vllm_ext.mxfp4_qdq_utils as mxfp4_utils -from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( - dequant_mxfp4_to_fp8, - mxfp4_gemm_with_unpacked_weight, - run_mxfp4_emulations, -) from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod logger = init_logger(__name__) From d03ac8bdaefb9fec15b8ace377f80672bb566633 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 25 Nov 2025 00:46:56 -0800 Subject: [PATCH 46/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/README.md | 8 +- auto_round_extension/vllm_ext/apply_ext.sh | 46 ---- .../vllm_ext/mxfp8_qdq_utils.py | 9 +- .../vllm_ext/torchao_patch.py | 257 ------------------ auto_round_extension/vllm_ext/utils.py | 63 ++++- 5 files changed, 69 insertions(+), 314 deletions(-) delete mode 100644 auto_round_extension/vllm_ext/apply_ext.sh delete mode 100644 auto_round_extension/vllm_ext/torchao_patch.py diff --git 
a/auto_round_extension/vllm_ext/README.md b/auto_round_extension/vllm_ext/README.md index 790d7796b..1f0cb832f 100644 --- a/auto_round_extension/vllm_ext/README.md +++ b/auto_round_extension/vllm_ext/README.md @@ -1,14 +1,10 @@ - Build and Install vLLM ``` -https://github.com/yiliu30/vllm-fork/tree/fused-moe-ar +git clone --branch fused-moe-ar https://github.com/yiliu30/vllm-fork.git VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv ``` -- Apply vLLM-Ext Patches(allow python recognize them) -``` -cd auto-round/auto_round_extension/vllm_ext -source apply_ext.sh -``` + - Enable vLLM-Ext at Runtime ```bash diff --git a/auto_round_extension/vllm_ext/apply_ext.sh b/auto_round_extension/vllm_ext/apply_ext.sh deleted file mode 100644 index b327ec22c..000000000 --- a/auto_round_extension/vllm_ext/apply_ext.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Define the relative path for the `auto-round` installation -AUTO_ROUND_PATH="auto_round/../auto_round_extension/vllm_ext/sitecustomize.py" - -# Try to find the pip installation location -PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}') - -if [ -n "$PIP_LOCATION" ]; then - SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH" - echo "Checking for sitecustomize.py at: $SITE_CUSTOMIZE_PATH" - - if [ -f "$SITE_CUSTOMIZE_PATH" ]; then - echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH" - export PYTHONPATH=$(dirname "$SITE_CUSTOMIZE_PATH"):$PYTHONPATH - echo "PYTHONPATH set to: $PYTHONPATH" - return 0 2>/dev/null || true - fi -fi - -# Fallback: check current directory -LOCAL_SITE_CUSTOMIZE="./sitecustomize.py" -if [ -f "$LOCAL_SITE_CUSTOMIZE" ]; then - echo "Found sitecustomize.py at current directory." - export PYTHONPATH=$(pwd):$PYTHONPATH - echo "PYTHONPATH set to: $PYTHONPATH" - return 0 2>/dev/null || true -fi - -echo "Warning: sitecustomize.py not found in pip installation or current directory." 
-# Do not exit the shell -return 1 2>/dev/null || true \ No newline at end of file diff --git a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py index 6b7ee60d9..2ec365b30 100644 --- a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py @@ -52,14 +52,15 @@ def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size, target_dtype): return dequant_weight.to(target_dtype) + + + def quant_mx_fp8(tensor): - from auto_round_extension.vllm_ext.torchao_patch import ScaleCalculationMode, to_mx + from auto_round_extension.vllm_ext.utils import to_mx_fp8e4m3 - scale_e8m0_biased, data_lp = to_mx( + scale_e8m0_biased, data_lp = to_mx_fp8e4m3( data_hp=tensor, elem_dtype=torch.float8_e4m3fn, block_size=32, - scaling_mode=ScaleCalculationMode.RCEIL, - pack_fp6=False, ) return scale_e8m0_biased, data_lp diff --git a/auto_round_extension/vllm_ext/torchao_patch.py b/auto_round_extension/vllm_ext/torchao_patch.py deleted file mode 100644 index e8c891bb4..000000000 --- a/auto_round_extension/vllm_ext/torchao_patch.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum, auto -from typing import Union - -import torch - -from .utils import _to_mx_rceil, get_fp_scale - - -class ScaleCalculationMode(Enum): - """ - Enum representing the different methods for calculating MX block scaling. - There are three methods available: - FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp). - It result in overflow issues for large values and bad for gradient quantization. - CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor. - It uses X = 2^ceil(log2(max_abs(v))-max_exp). - EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)). - It provides better accuracy for MX4 training compared to FLOOR and CEIL. - RCEIL: The method is to apply ceil to the ratio of max_abs(v) and max_pos. - This method's detail is described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization - Section "Computing scaling and conversion factors for FP8 with UE8M0 scales" - - By default, we use the EVEN method for better accuracy. - """ - - FLOOR = auto() - CEIL = auto() - EVEN = auto() - RCEIL = auto() - - -# This is conceptually an enum of non-core dtypes -# TODO(future PR): change to a cleaner way to represent this without -# regressing torch.compile and while keeping things readable. 
-DTYPE_FP6_E3M2 = "fp6_e3m2" -DTYPE_FP6_E2M3 = "fp6_e2m3" - -# Supported element dtypes -# TODO(future PR): add support for MX int8 -SUPPORTED_ELEM_DTYPES = [ - torch.float8_e4m3fn, - torch.float8_e5m2, - DTYPE_FP6_E2M3, - DTYPE_FP6_E3M2, -] - - -F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 -F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0 - -F8E4M3_MAX_POW2 = 8 # 256 -F8E5M2_MAX_POW2 = 15 # 32768 -F6_E2M3_MAX_POW2 = 2 # 4 -F6_E3M2_MAX_POW2 = 4 # 16 -F4_E2M1_MAX_POW2 = 2 # 4 - -E8M0_EXPONENT_BIAS = 127 -E8M0_EXPONENT_NAN_VAL = 255 - -F32_EXP_BIAS = 127 -BF16_EXP_BIAS = 127 -F6_E2M3_EXP_BIAS = 1 -F6_E3M2_EXP_BIAS = 3 -F4_E2M1_EXP_BIAS = 1 - -F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) - -F6_E2M3_MAX = 7.5 -F6_E2M3_MIN_NORMAL = 1.0 -F6_E2M3_MAX_INT = 31 # integer corresponding to 0b00011111 - -F6_E3M2_MAX = 28.0 -F6_E3M2_MIN_NORMAL = 0.25 -F6_E3M2_MAX_INT = 31 # integer corresponding to 0b00011111 - -F4_E2M1_MAX = 6.0 -F4_E2M1_MIN_NORMAL = 1.0 -F4_E2M1_MAX_INT = 7 - -BLOCK_SIZE_DEFAULT = 32 - - -# TODO(later): read from somewhere else? -SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 -EBITS_BF16, MBITS_BF16 = 8, 7 -EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 -EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 -EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 -EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3 -EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2 - - -def to_mx( - data_hp: torch.Tensor, - elem_dtype: Union[torch.dtype, str], - block_size: int, - scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, - pack_fp6: bool = False, -): - """ - Takes a high precision tensor and converts to MX scale and raw data, in - naive layout (scale and raw data are separate tensors). - """ - - assert data_hp.dtype in ( - torch.bfloat16, - torch.float, - ), f"{data_hp.dtype} is not supported yet" - # TODO(future PR): consider supporting padding - data_hp = data_hp.contiguous() - assert data_hp.numel() % block_size == 0, "unsupported" - assert data_hp.is_contiguous(), "unsupported" - assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" - - # calculate the scale in e8m0 format - - orig_shape = data_hp.shape - data_hp = data_hp.reshape(-1, block_size) - - # find max value of the data - # Note: this only implements the `minimally supported` version of - # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf - # section 6.3. - max_abs = torch.amax(torch.abs(data_hp), 1) - - # Add an epsilon to prevent the log2 function call for returning -inf - # where the values are zero. 
- eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) - - # Set X to be the largest power-of-two less than or equal to - # max_abs(v), divided by the largest power of two representable - # in the element data type, and get the mbits at the same time - if elem_dtype == torch.float8_e4m3fn: - target_max_pow2 = F8E4M3_MAX_POW2 - mbits = MBITS_F8_E4M3 - max_pos = F8E4M3_MAX - else: - raise AssertionError("unsupported element dtype") - - if scaling_mode == ScaleCalculationMode.RCEIL: - scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) - else: - if data_hp.dtype is torch.float32: - hp_int_dtype = torch.int32 - hp_mbits = MBITS_F32 - hp_ebits = EBITS_F32 - hp_exp_bias = F32_EXP_BIAS - else: - assert data_hp.dtype is torch.bfloat16 - hp_int_dtype = torch.int16 - hp_mbits = MBITS_BF16 - hp_ebits = EBITS_BF16 - hp_exp_bias = BF16_EXP_BIAS - - # rounding before calculating the largest power of 2 - # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) - if scaling_mode == ScaleCalculationMode.EVEN: - nan_mask = torch.isnan(max_abs) - max_abs = max_abs.view(hp_int_dtype) - val_to_add = 1 << (hp_mbits - mbits - 1) - mask = ((1 << (hp_ebits + SBITS)) - 1) << hp_mbits - max_abs = (max_abs + val_to_add) & mask - max_abs = max_abs.view(data_hp.dtype) - max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device, dtype=max_abs.dtype) - - # Calculate the scale for different modes - max_abs_int32 = (max_abs + eps).view(hp_int_dtype) - extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias - - if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): - scale_e8m0_unbiased = extracted_pow2 - target_max_pow2 - elif scaling_mode == ScaleCalculationMode.CEIL: - # round up: add one to scale if the mantissa is larger than 0 - # 0x7FFFFF is equal to 23 ones - mantissa_gt_one = (max_abs_int32 & 0x7FFFFF) > 0 - extracted_pow2 += mantissa_gt_one - scale_e8m0_unbiased = extracted_pow2 - target_max_pow2 - else: - raise AssertionError("unsupported scaling calculation mode") - - # Clamp to exponents that can be represented in e8m0 - # add one to positive range to capture NaNs - scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1) - - # Create the biased e8m0 representation and cast it to 8 bits - scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS - scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8) - - # Conversion to torch.uint8 sets NaN values to 0, fix this by - # explicitly setting known NaN values to 255 - scale_e8m0_biased = torch.where( - torch.isnan(max_abs), - E8M0_EXPONENT_NAN_VAL, - scale_e8m0_biased, - ) - - # For now, calculate the scale in floating point. - scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32) - - # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the - # float32 denormal range. For now, manually adjust the fp scale. This is - # relevant if all of the incoming block values are zeroes. - # See https://github.com/pytorch/pytorch/issues/125557 for details. - # Note: it would be more correct to set the minimum to 2**-127, but this - # does not work in triton either as it looks like subnormal value handling - # has some gaps. So, for now just set to the minimum normal value. 
- scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) - - # scale and saturated cast the data elements to max of target dtype - data_lp = data_hp / scale_fp32.unsqueeze(1) - - if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) and not torch._dynamo.is_compiling(): - # As of 20250317, the Pytorch eager mode cast to `torch.float8_e4m3fn` - # is unsaturated. This cast is saturated in triton. If we are compute bound, - # we see a speedup if we remove this redundant clamp if we are compiling - # to triton. - # TODO(#1912): make the saturated cast work in eager mode and remove this - # workaround. - data_lp = torch.clamp(data_lp, min=-1 * max_pos, max=max_pos) - - # cast to target dtype - if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - data_lp = data_lp.to(elem_dtype) - # need to reshape at the end to help inductor fuse things - data_lp = data_lp.reshape(orig_shape) - else: - raise AssertionError("unsupported") - - # scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) - return scale_e8m0_biased, data_lp - - -def down_size(size): - assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" - return (*size[:-1], size[-1] // 2) - - -def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 2 == 0 - uint8_data = uint8_data.contiguous().view(-1) - return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) diff --git a/auto_round_extension/vllm_ext/utils.py b/auto_round_extension/vllm_ext/utils.py index f499a1c02..69fd70f54 100644 --- a/auto_round_extension/vllm_ext/utils.py +++ b/auto_round_extension/vllm_ext/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import torch - +from typing import Union from auto_round.schemes import QuantizationScheme E8M0_EXPONENT_BIAS = 127 @@ -106,3 +106,64 @@ def _to_mx_rceil( # scale and saturated cast the data elements to max of target dtype data_lp = torch.clamp(data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos) return exponent, data_lp + + +def to_mx_fp8e4m3( + data_hp: torch.Tensor, + elem_dtype: Union[torch.dtype, str], + block_size: int, + +): + """ + Takes a high precision tensor and converts to MX scale and raw data, in + naive layout (scale and raw data are separate tensors). + """ + + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + # TODO(future PR): consider supporting padding + data_hp = data_hp.contiguous() + + # calculate the scale in e8m0 format + orig_shape = data_hp.shape + data_hp = data_hp.reshape(-1, block_size) + + # find max value of the data + # Note: this only implements the `minimally supported` version of + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + # section 6.3. 
+ max_abs = torch.amax(torch.abs(data_hp), 1) + + # Set X to be the largest power-of-two less than or equal to + # max_abs(v), divided by the largest power of two representable + # in the element data type, and get the mbits at the same time + assert elem_dtype == torch.float8_e4m3fn, f"only float8_e4m3fn is supported now, got {elem_dtype}" + + max_pos = torch.finfo(torch.float8_e4m3fn).max # 448.0 + + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + + data_lp = data_lp.to(elem_dtype) + # need to reshape at the end to help inductor fuse things + data_lp = data_lp.reshape(orig_shape) + + + # scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + return scale_e8m0_biased, data_lp + + + +def down_size(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + + +def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) From f3bb6492a9d87f73012791cc24437edf93a25fcf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 08:49:52 +0000 Subject: [PATCH 47/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/mxfp8_qdq_utils.py | 3 --- auto_round_extension/vllm_ext/utils.py | 10 ++++------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py index 2ec365b30..482d77f85 100644 --- a/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp8_qdq_utils.py @@ -52,9 +52,6 @@ def dequant_mx_fp8(weight_fp8, scale_e8m0, block_size, target_dtype): return dequant_weight.to(target_dtype) - - - def quant_mx_fp8(tensor): from auto_round_extension.vllm_ext.utils import to_mx_fp8e4m3 diff --git a/auto_round_extension/vllm_ext/utils.py b/auto_round_extension/vllm_ext/utils.py index 69fd70f54..c817dba76 100644 --- a/auto_round_extension/vllm_ext/utils.py +++ b/auto_round_extension/vllm_ext/utils.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch from typing import Union + +import torch + from auto_round.schemes import QuantizationScheme E8M0_EXPONENT_BIAS = 127 @@ -112,7 +114,6 @@ def to_mx_fp8e4m3( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int, - ): """ Takes a high precision tensor and converts to MX scale and raw data, in @@ -145,16 +146,13 @@ def to_mx_fp8e4m3( scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) - data_lp = data_lp.to(elem_dtype) # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) - # scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) return scale_e8m0_biased, data_lp - - + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" From 97df38086de1625e813765b7dda15c76796fec67 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 25 Nov 2025 19:39:29 -0800 Subject: [PATCH 48/52] fix import Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/__init__.py | 8 ++------ auto_round_extension/vllm_ext/auto_round_ext.py | 14 +++++++++----- auto_round_extension/vllm_ext/tests/test_models.py | 4 +++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 3e145334f..5f9f89271 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -18,9 +18,5 @@ def apply(): - import vllm.model_executor.layers.quantization.auto_round as auto_round_module - - from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig - - auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig - from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables + import auto_round_extension.vllm_ext.auto_round_ext + import auto_round_extension.vllm_ext.envs_ext diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index 23d441daa..83b4ca212 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -18,7 +18,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig +from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig as _BaseAutoRoundConfig from auto_round.schemes import QuantizationScheme from auto_round_extension.vllm_ext.quant_method_linear import AutoRoundQuantLinearMethod @@ -27,9 +27,9 @@ logger = init_logger(__name__) -class AutoRoundExtensionConfig(AutoRoundConfig): - SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"}) - SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"}) +class AutoRoundExtensionConfig(_BaseAutoRoundConfig): + SUPPORTED_DTYPES = _BaseAutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"}) + SUPPORTED_FORMATS = _BaseAutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"}) def get_quant_method(self, layer: torch.nn.Module, prefix: str): # FIXME: (yi) make it compatible with `AutoRoundConfig` @@ -49,7 +49,7 @@ def _parse_quant_scheme(config: dict): return quant_scheme @classmethod - def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig: + def from_config(cls, config: dict[str, Any]) -> _BaseAutoRoundConfig: ar_config = super().from_config(config) # TODO: (yi) refine below implementation 
quant_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config) @@ -62,3 +62,7 @@ def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig: ar_config.quant_scheme = quant_scheme ar_config.layer_schemes = layer_schemes return ar_config + +# Patch vLLM’s AutoRoundConfig at import time +import vllm.model_executor.layers.quantization.auto_round as _auto_round_module +_auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig \ No newline at end of file diff --git a/auto_round_extension/vllm_ext/tests/test_models.py b/auto_round_extension/vllm_ext/tests/test_models.py index df748eb10..2b1343235 100644 --- a/auto_round_extension/vllm_ext/tests/test_models.py +++ b/auto_round_extension/vllm_ext/tests/test_models.py @@ -20,7 +20,9 @@ # "/data5/yliu7/HF_HOME/Qwen2.5-0.5B-Instruct-test-FP8_STATIC-fp8kv/" # "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4", # "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP4-fp8attention", - "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP8" + # "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP8" + "/home/yiliu7/workspace/inc/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qmodels/quantized_model_qwen_mxfp4", + "/home/yiliu7/workspace/inc/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qmodels/quantized_model_qwen_mxfp8", ] From 706a075071fe6711b658bd029dda727d6a097e48 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 25 Nov 2025 19:40:33 -0800 Subject: [PATCH 49/52] fix model path Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round_extension/vllm_ext/tests/test_models.py b/auto_round_extension/vllm_ext/tests/test_models.py index 2b1343235..c01ce11c3 100644 --- a/auto_round_extension/vllm_ext/tests/test_models.py +++ b/auto_round_extension/vllm_ext/tests/test_models.py @@ -21,8 +21,8 @@ # "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4", # "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP4-fp8attention", # "/data6/yiliu4/Llama-3.2-1B-Instruct-MXFP8" - "/home/yiliu7/workspace/inc/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qmodels/quantized_model_qwen_mxfp4", - "/home/yiliu7/workspace/inc/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qmodels/quantized_model_qwen_mxfp8", + "/storage/yiliu7/ar_vllm_ext/quantized_model_qwen_mxfp4", + "/storage/yiliu7/ar_vllm_ext/quantized_model_qwen_mxfp8", ] From 458034cfc9945ce41f760762b4c3eae28e7d49ee Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 25 Nov 2025 19:47:05 -0800 Subject: [PATCH 50/52] update ut Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/__init__.py | 4 ++++ auto_round_extension/vllm_ext/tests/test_models.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 5f9f89271..217869e01 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -20,3 +20,7 @@ def apply(): import auto_round_extension.vllm_ext.auto_round_ext import auto_round_extension.vllm_ext.envs_ext + print("*****************************************************************************") + print(f"* !!! 
VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *") + print("*****************************************************************************") + diff --git a/auto_round_extension/vllm_ext/tests/test_models.py b/auto_round_extension/vllm_ext/tests/test_models.py index c01ce11c3..faa783cde 100644 --- a/auto_round_extension/vllm_ext/tests/test_models.py +++ b/auto_round_extension/vllm_ext/tests/test_models.py @@ -26,6 +26,16 @@ ] +@pytest.fixture(autouse=True) +def set_vllm_ar_env(monkeypatch): + monkeypatch.setenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") + monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "1") + monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") + monkeypatch.setenv("VLLM_ENABLE_STATIC_MOE", "0") + monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0") + monkeypatch.setenv("VLLM_ENABLE_AR_EXT", "1") + + @pytest.mark.skipif( not current_platform.is_cuda(), reason="only supports CUDA backend.", From a1669c2c1d4c13ee99b1ceecb80ff4e9110c1b08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 03:48:07 +0000 Subject: [PATCH 51/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/__init__.py | 4 ++-- auto_round_extension/vllm_ext/auto_round_ext.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 217869e01..748e917e2 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -20,7 +20,7 @@ def apply(): import auto_round_extension.vllm_ext.auto_round_ext import auto_round_extension.vllm_ext.envs_ext + print("*****************************************************************************") - print(f"* !!! VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *") + print("* !!! 
VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *") print("*****************************************************************************") - diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index 83b4ca212..d665fd568 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -63,6 +63,8 @@ def from_config(cls, config: dict[str, Any]) -> _BaseAutoRoundConfig: ar_config.layer_schemes = layer_schemes return ar_config + # Patch vLLM’s AutoRoundConfig at import time import vllm.model_executor.layers.quantization.auto_round as _auto_round_module -_auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig \ No newline at end of file + +_auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig From cabe4b4a17b1d87a563d047f0bb13f4277868d19 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 25 Nov 2025 20:45:37 -0800 Subject: [PATCH 52/52] update envs Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/envs_ext.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/auto_round_extension/vllm_ext/envs_ext.py b/auto_round_extension/vllm_ext/envs_ext.py index e35ed41bc..70fbb3238 100644 --- a/auto_round_extension/vllm_ext/envs_ext.py +++ b/auto_round_extension/vllm_ext/envs_ext.py @@ -21,10 +21,10 @@ # Define extra environment variables extra_environment_variables: dict[str, Callable[[], Any]] = { - "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"), - "VLLM_MXFP4_PRE_UNPACK_TO_FP8": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "0") in ("1", "true", "True"), - "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"), - "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"), + "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") in ("1", "true", "True"), + "VLLM_MXFP4_PRE_UNPACK_TO_FP8": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "1") in ("1", "true", "True"), + "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "0") in ("1", "true", "True"), + "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") in ("1", "true", "True"), "VLLM_AR_POST_PROCESS_GPTOSS": lambda: os.getenv("VLLM_AR_POST_PROCESS_GPTOSS", "0") in ("1", "true", "True"), } # Add the extra environment variables to vllm.envs