-
Notifications
You must be signed in to change notification settings - Fork 60
[1/N] Initial vllm-ext evaluation support (MXFP4 MOE) #935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
5ee2c2d
init moe support
yiliu30 c278f9d
add test
yiliu30 418e6a0
fix import
yiliu30 184783f
clean envs
yiliu30 b9da06f
add script for apply ext
yiliu30 187f38d
clean docs
yiliu30 4031724
fix license
yiliu30 5fe01ef
fix
yiliu30 73f1e9b
fix import and sitecustomize
yiliu30 8495854
move to ext
yiliu30 c473934
update mxfp4
yiliu30 9f65bd1
fix
yiliu30 8038a5f
fix model name
yiliu30 e0872b6
Merge branch 'main' into vllm-ext
yiliu30 c82bce1
fix
yiliu30 19e18c7
Merge branch 'vllm-ext' of https://github.com/intel/auto-round into v…
yiliu30 adf7ebf
use absolute path
yiliu30 59f5cd2
Merge branch 'main' into vllm-ext
yiliu30 8f27041
Merge branch 'main' into vllm-ext
yiliu30 ad8537c
fix
yiliu30 77844f6
mark round method as todo
yiliu30 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # ==---------------------------------------------------------------------------== | ||
| # Apply the extension | ||
| # ==---------------------------------------------------------------------------== | ||
|
|
||
|
|
||
| def apply(): | ||
| import vllm.model_executor.layers.quantization.auto_round as auto_round_module | ||
|
|
||
| from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig | ||
|
|
||
| auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig | ||
| from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Define the relative path for the `auto-round` installation | ||
| AUTO_ROUND_PATH="auto_round/../auto_round_extension/vllm_ext/sitecustomize.py" | ||
|
|
||
| # Try to find the pip installation location | ||
| PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}') | ||
|
|
||
| if [ -n "$PIP_LOCATION" ]; then | ||
| SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH" | ||
| echo "Checking for sitecustomize.py at: $SITE_CUSTOMIZE_PATH" | ||
|
|
||
| if [ -f "$SITE_CUSTOMIZE_PATH" ]; then | ||
| echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH" | ||
| export PYTHONPATH=$(dirname "$SITE_CUSTOMIZE_PATH"):$PYTHONPATH | ||
| echo "PYTHONPATH set to: $PYTHONPATH" | ||
| return 0 2>/dev/null || true | ||
| fi | ||
| fi | ||
|
|
||
| # Fallback: check current directory | ||
| LOCAL_SITE_CUSTOMIZE="./sitecustomize.py" | ||
| if [ -f "$LOCAL_SITE_CUSTOMIZE" ]; then | ||
| echo "Found sitecustomize.py at current directory." | ||
| export PYTHONPATH=$(pwd):$PYTHONPATH | ||
| echo "PYTHONPATH set to: $PYTHONPATH" | ||
| return 0 2>/dev/null || true | ||
| fi | ||
|
|
||
| echo "Warning: sitecustomize.py not found in pip installation or current directory." | ||
| # Do not exit the shell | ||
| return 1 2>/dev/null || true |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Any | ||
|
|
||
| import torch | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.fused_moe import FusedMoE | ||
| from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod | ||
| from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig | ||
|
|
||
| from auto_round.schemes import QuantizationScheme | ||
| from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class AutoRoundExtensionConfig(AutoRoundConfig): | ||
| SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"}) | ||
| SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"}) | ||
|
|
||
| def get_quant_method(self, layer: torch.nn.Module, prefix: str): | ||
| # FIXME: (yi) make it compatible with `AutoRoundConfig` | ||
| if isinstance(layer, FusedMoE): | ||
| quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix) | ||
| return quant_method | ||
| elif isinstance(layer, LinearBase): | ||
| return UnquantizedLinearMethod() | ||
| else: | ||
| return None | ||
|
|
||
| @staticmethod | ||
| def _parse_quant_scheme(config: dict): | ||
| quant_scheme_attrs = QuantizationScheme.get_attributes() | ||
| filter_config = {key: value for key, value in config.items() if key in quant_scheme_attrs} | ||
| quant_scheme = QuantizationScheme.from_dict(filter_config) | ||
| return quant_scheme | ||
|
|
||
| @classmethod | ||
| def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig: | ||
| ar_config = super().from_config(config) | ||
| # TODO: (yi) refine below implementation | ||
| quant_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config) | ||
| layer_schemes = {} | ||
| layer_schemes = {} # ensure dict | ||
| extra_config = getattr(ar_config, "extra_config", None) | ||
| if extra_config is not None: | ||
| for layer_name, layer_config in extra_config.items(): | ||
| layer_schemes[layer_name] = AutoRoundExtensionConfig._parse_quant_scheme(layer_config) | ||
| ar_config.quant_scheme = quant_scheme | ||
| ar_config.layer_schemes = layer_schemes | ||
| return ar_config | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
| from typing import Any, Callable | ||
|
|
||
| from vllm.logger import init_logger | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| # Define extra environment variables | ||
| extra_environment_variables: dict[str, Callable[[], Any]] = { | ||
| "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"), | ||
| "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"), | ||
| "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"), | ||
| } | ||
| # Add the extra environment variables to vllm.envs | ||
| import vllm.envs as envs | ||
| from vllm.envs import environment_variables | ||
|
|
||
| # Merge the environment variables | ||
| all_environment_variables = {**environment_variables, **extra_environment_variables} | ||
|
|
||
|
|
||
| for name, value_fn in extra_environment_variables.items(): | ||
| setattr(envs, name, value_fn()) | ||
|
|
||
| logger.warning_once(f"Added extra environment variables: {list(extra_environment_variables.keys())}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| # Module-level device tensor cache to fix cuda graph issue | ||
| _DEVICE_E2M1_TENSORS = {} | ||
|
|
||
| # Constants for FP4 values (E2M1 format) | ||
| _E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] | ||
|
|
||
|
|
||
| def get_e2m1_tensor(device): | ||
| """Get device-specific E2M1 lookup tensor, creating it if needed.""" | ||
| device_str = str(device) | ||
| if device_str not in _DEVICE_E2M1_TENSORS: | ||
| _DEVICE_E2M1_TENSORS[device_str] = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device) | ||
| return _DEVICE_E2M1_TENSORS[device_str] | ||
|
|
||
|
|
||
| def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: | ||
| m, n = x.shape | ||
| device = x.device | ||
|
|
||
| # Create lookup table for FP4 values to indices | ||
| # Map the absolute values to 0-7 indices | ||
| kE2M1 = get_e2m1_tensor(x.device) | ||
|
|
||
| # Find closest valid FP4 value index for each element | ||
| abs_x = torch.abs(x) | ||
| abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8] | ||
| abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n] | ||
|
|
||
| # Apply sign bit (bit 3) to get final 4-bit representation | ||
| indices = abs_indices + (torch.signbit(x).to(torch.long) << 3) | ||
|
|
||
| # Reshape to prepare for packing pairs of values | ||
| indices = indices.reshape(-1) | ||
|
|
||
| # Handle odd length by padding if necessary | ||
| assert indices.numel() % 2 != 0, f"Expected even number of elements, got {indices.numel()}" | ||
|
|
||
| # Reshape to pair consecutive elements | ||
| indices = indices.reshape(-1, 2) | ||
|
|
||
| # Pack pairs of 4-bit values into 8-bit values | ||
| packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8) | ||
|
|
||
| return packed.reshape(m, n // 2) | ||
|
|
||
|
|
||
| def unpack_fp4_from_uint8( | ||
| a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values | ||
| (i.e. first four bits correspond to one fp4 value, last four correspond to a | ||
| consecutive fp4 value). The bits represent an index, which are mapped to an fp4 | ||
| value. | ||
|
|
||
| :param a: tensor to unpack | ||
| :param m: original dim 0 size of the unpacked tensor | ||
| :param n: original dim 1 size of the unpacked tensor | ||
| :param dtype: dense dtype to cast the unpacked tensor to | ||
| """ | ||
| assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}" | ||
|
|
||
| # Vectorized nibble processing | ||
| a_flat = a.flatten() | ||
| high = (a_flat & 0xF0) >> 4 # Upper nibbles | ||
| low = a_flat & 0x0F # Lower nibbles | ||
|
|
||
| # Combine nibbles for batch processing | ||
| combined = torch.stack((low, high), dim=1).flatten() | ||
|
|
||
| # Vectorized sign and magnitude extraction | ||
| signs = (combined & 0x08).to(torch.bool) # Sign bits | ||
| abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices | ||
|
|
||
| # Device-aware lookup and sign application | ||
| kE2M1 = get_e2m1_tensor(a.device) | ||
|
|
||
| values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) | ||
|
|
||
| # Reshape to final form | ||
| return values.reshape(m, n).to(dtype=dtype) | ||
|
|
||
|
|
||
| def cast_to_fp4(x): | ||
| sign = torch.sign(x) | ||
| x = torch.abs(x) | ||
| x[(x >= 0.0) & (x <= 0.25)] = 0.0 | ||
| x[(x > 0.25) & (x < 0.75)] = 0.5 | ||
| x[(x >= 0.75) & (x <= 1.25)] = 1.0 | ||
| x[(x > 1.25) & (x < 1.75)] = 1.5 | ||
| x[(x >= 1.75) & (x <= 2.5)] = 2.0 | ||
| x[(x > 2.5) & (x < 3.5)] = 3.0 | ||
| x[(x >= 3.5) & (x <= 5.0)] = 4.0 | ||
| x[x > 5.0] = 6.0 | ||
| return x * sign |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.