Merged
26 changes: 26 additions & 0 deletions auto_round_extension/vllm_ext/__init__.py
@@ -0,0 +1,26 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ==---------------------------------------------------------------------------==
# Apply the extension
# ==---------------------------------------------------------------------------==


def apply():
import vllm.model_executor.layers.quantization.auto_round as auto_round_module

from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig

auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig
    # Importing envs_ext registers the extra environment variables with vllm.envs
    from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables  # noqa: F401
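A minimal usage sketch (not part of the diff): how `apply()` is expected to be invoked before vLLM builds its quantization config, for example from a `sitecustomize.py` on `PYTHONPATH`. The model path below is a placeholder.

```python
# Hypothetical usage sketch: patch vLLM's AutoRoundConfig before loading a model.
from auto_round_extension.vllm_ext import apply

apply()  # replaces vllm's AutoRoundConfig with AutoRoundExtensionConfig

from vllm import LLM  # imported after apply() so the patched config is used

llm = LLM(model="/path/to/auto-round-quantized-model")  # placeholder path
```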
46 changes: 46 additions & 0 deletions auto_round_extension/vllm_ext/apply_ext.sh
@@ -0,0 +1,46 @@
#!/bin/bash

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
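
# NOTE: this script is meant to be sourced (e.g. `source apply_ext.sh`) so that
# the exported PYTHONPATH and the `return` statements below take effect in the
# calling shell.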

# Relative path of the vllm_ext sitecustomize.py within the auto-round installation
AUTO_ROUND_PATH="auto_round/../auto_round_extension/vllm_ext/sitecustomize.py"

# Try to find the pip installation location
PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}')

if [ -n "$PIP_LOCATION" ]; then
SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH"
echo "Checking for sitecustomize.py at: $SITE_CUSTOMIZE_PATH"

if [ -f "$SITE_CUSTOMIZE_PATH" ]; then
echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH"
export PYTHONPATH=$(dirname "$SITE_CUSTOMIZE_PATH"):$PYTHONPATH
echo "PYTHONPATH set to: $PYTHONPATH"
return 0 2>/dev/null || true
fi
fi

# Fallback: check current directory
LOCAL_SITE_CUSTOMIZE="./sitecustomize.py"
if [ -f "$LOCAL_SITE_CUSTOMIZE" ]; then
echo "Found sitecustomize.py at current directory."
export PYTHONPATH=$(pwd):$PYTHONPATH
echo "PYTHONPATH set to: $PYTHONPATH"
return 0 2>/dev/null || true
fi

echo "Warning: sitecustomize.py not found in pip installation or current directory."
# Do not exit the shell
return 1 2>/dev/null || true
63 changes: 63 additions & 0 deletions auto_round_extension/vllm_ext/auto_round_ext.py
@@ -0,0 +1,63 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig

from auto_round.schemes import QuantizationScheme
from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod

logger = init_logger(__name__)


class AutoRoundExtensionConfig(AutoRoundConfig):
SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"})
SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"})

def get_quant_method(self, layer: torch.nn.Module, prefix: str):
# FIXME: (yi) make it compatible with `AutoRoundConfig`
if isinstance(layer, FusedMoE):
quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
return quant_method
elif isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
else:
return None

@staticmethod
def _parse_quant_scheme(config: dict):
quant_scheme_attrs = QuantizationScheme.get_attributes()
filter_config = {key: value for key, value in config.items() if key in quant_scheme_attrs}
quant_scheme = QuantizationScheme.from_dict(filter_config)
return quant_scheme

@classmethod
def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig:
ar_config = super().from_config(config)
# TODO: (yi) refine below implementation
quant_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config)
        layer_schemes = {}
extra_config = getattr(ar_config, "extra_config", None)
if extra_config is not None:
for layer_name, layer_config in extra_config.items():
layer_schemes[layer_name] = AutoRoundExtensionConfig._parse_quant_scheme(layer_config)
ar_config.quant_scheme = quant_scheme
ar_config.layer_schemes = layer_schemes
return ar_config
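An illustrative sketch (not part of the diff) of the kind of `quantization_config` dict this class is meant to digest. The key names below are assumptions based on typical AutoRound exports, not the definitive schema, and real configs carry additional fields required by vLLM's base `AutoRoundConfig`.

```python
from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig

# Hypothetical config sketch; key names are assumptions, not the definitive schema.
example_config = {
    "bits": 4,
    "group_size": 32,
    "sym": True,
    "data_type": "mx_fp",  # dtype newly accepted via SUPPORTED_DTYPES
    "extra_config": {      # optional per-layer overrides
        "model.layers.0.mlp.gate_proj": {"bits": 8, "group_size": 128},
    },
}

# The global scheme keeps only keys that match QuantizationScheme attributes.
global_scheme = AutoRoundExtensionConfig._parse_quant_scheme(example_config)
# Per-layer schemes are parsed the same way from each extra_config entry.
layer_scheme = AutoRoundExtensionConfig._parse_quant_scheme(
    example_config["extra_config"]["model.layers.0.mlp.gate_proj"]
)
```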
39 changes: 39 additions & 0 deletions auto_round_extension/vllm_ext/envs_ext.py
@@ -0,0 +1,39 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Callable

from vllm.logger import init_logger

logger = init_logger(__name__)

# Define extra environment variables
extra_environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"),
"VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"),
"VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"),
}
# Add the extra environment variables to vllm.envs
import vllm.envs as envs
from vllm.envs import environment_variables

# Merge the environment variables
all_environment_variables = {**environment_variables, **extra_environment_variables}


for name, value_fn in extra_environment_variables.items():
setattr(envs, name, value_fn())

logger.warning_once(f"Added extra environment variables: {list(extra_environment_variables.keys())}")
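A short sketch (not part of the diff) of how the extra flags become visible once this module has been imported, directly or via `apply()`:

```python
# Sketch: the extra flags are attached to vllm.envs as plain attributes.
import auto_round_extension.vllm_ext.envs_ext  # noqa: F401  (side effect: registers flags)
import vllm.envs as envs

if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
    # e.g. take the pre-unpacked-weights code path
    pass
```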
113 changes: 113 additions & 0 deletions auto_round_extension/vllm_ext/fp4_utils.py
@@ -0,0 +1,113 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

# Module-level cache of per-device E2M1 lookup tensors (avoids re-creating the
# tensor on every call, which can break CUDA graph capture)
_DEVICE_E2M1_TENSORS = {}

# Constants for FP4 values (E2M1 format)
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def get_e2m1_tensor(device):
"""Get device-specific E2M1 lookup tensor, creating it if needed."""
device_str = str(device)
if device_str not in _DEVICE_E2M1_TENSORS:
_DEVICE_E2M1_TENSORS[device_str] = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device)
return _DEVICE_E2M1_TENSORS[device_str]


def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
    kE2M1 = get_e2m1_tensor(device)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

    # The pairwise packing below requires an even number of elements
    assert indices.numel() % 2 == 0, f"Expected even number of elements, got {indices.numel()}"

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

return packed.reshape(m, n // 2)


def unpack_fp4_from_uint8(
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
"""
    Unpacks uint8 values into fp4. Each uint8 holds two fp4 values: the low
    nibble is the first value and the high nibble is the following value. Each
    4-bit code is an index (with a sign bit) that is mapped back to an fp4 value.

:param a: tensor to unpack
:param m: original dim 0 size of the unpacked tensor
:param n: original dim 1 size of the unpacked tensor
:param dtype: dense dtype to cast the unpacked tensor to
"""
assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}"

# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles

# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()

# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices

# Device-aware lookup and sign application
kE2M1 = get_e2m1_tensor(a.device)

values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

# Reshape to final form
return values.reshape(m, n).to(dtype=dtype)


def cast_to_fp4(x):
    """Round each element to the nearest representable E2M1 (FP4) magnitude, keeping the sign."""
    sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign
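A small round-trip sketch (not part of the diff) for the helpers above; the tensor values are illustrative, and the number of elements per row must be even for packing.

```python
import torch

from auto_round_extension.vllm_ext.fp4_utils import (
    cast_to_fp4,
    pack_fp4_to_uint8,
    unpack_fp4_from_uint8,
)

x = torch.tensor([[0.4, -1.2, 2.6, -5.5]])      # [1, 4]; n must be even
x_fp4 = cast_to_fp4(x)                           # snap to E2M1 values: [[0.5, -1.0, 3.0, -6.0]]
packed = pack_fp4_to_uint8(x_fp4)                # [1, 2] uint8, two FP4 codes per byte
restored = unpack_fp4_from_uint8(packed, 1, 4)   # back to [1, 4], bfloat16 by default
assert torch.equal(restored.float(), x_fp4)
```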