From 861f9bbd44e9bca3c3992cecbc9af4986d11370c Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Mon, 24 Nov 2025 16:50:27 +0800 Subject: [PATCH] Revert "Add static FP8 attention support (#1045)" This reverts commit c5b1c4129021dc2f920673b1011fa7fbad4e6095. --- auto_round/__main__.py | 9 - auto_round/compressors/base.py | 13 +- auto_round/experimental/attention.py | 190 ------------------ auto_round/experimental/kv_cache.py | 77 +++++-- .../experimental/qmodules/fp8_static.py | 4 +- auto_round/experimental/utils.py | 75 ------- test/test_cpu/test_export.py | 38 +--- 7 files changed, 68 insertions(+), 338 deletions(-) delete mode 100644 auto_round/experimental/attention.py delete mode 100644 auto_round/experimental/utils.py diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 2965b7073..e73e5461e 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -153,13 +153,6 @@ def __init__(self, *args, **kwargs): basic.add_argument( "--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. " ) - basic.add_argument( - "--static_kv_dtype", default=None, type=str, help="Data type for static quantize key and value. " - ) - - basic.add_argument( - "--static_attention_dtype ", default=None, type=str, help="Data type for static quantize attention. " - ) tuning = self.add_argument_group("Tuning Arguments") tuning.add_argument( @@ -606,8 +599,6 @@ def tune(args): layer_config=layer_config, model_dtype=args.model_dtype, momentum=args.momentum, - static_kv_dtype=args.static_kv_dtype, - static_attention_dtype=args.static_attention_dtype, ) model_name = args.model.rstrip("/") diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8936b1e75..077719df1 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -237,7 +237,6 @@ def __init__( enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False) self.momentum = kwargs.pop("momentum", 0.0) static_kv_dtype = kwargs.pop("static_kv_dtype", None) - static_attention_dtype = kwargs.pop("static_attention_dtype", None) model_dtype = kwargs.pop("model_dtype", None) device = kwargs.pop("device", None) if envs.AR_USE_MODELSCOPE: @@ -357,11 +356,6 @@ def __init__( if self.static_kv_dtype is not None: logger.warning("The static kv is experimental and currently has limited support.") - # Attention static dtype - self.static_attention_dtype = static_attention_dtype - if self.static_attention_dtype is not None: - logger.warning("The static attention dtype is experimental and currently has limited support.") - self._set_amp_dtype() self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device if self.act_bits <= 8 and self.amp_dtype == torch.float16: @@ -1010,12 +1004,7 @@ def quantize_and_save( kwargs.pop("inplace", None) # Perform model quantization - if self.static_attention_dtype is not None: - from auto_round.experimental.attention import attention_quant_ctx - - with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype): - model, _ = self.quantize() - elif self.static_kv_dtype is not None: + if self.static_kv_dtype is not None: from auto_round.experimental.kv_cache import kvcache_quant_context with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype): diff --git a/auto_round/experimental/attention.py b/auto_round/experimental/attention.py deleted file mode 100644 index 2f82bc655..000000000 --- a/auto_round/experimental/attention.py +++ /dev/null @@ -1,190 +0,0 @@ 
-# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# NOTICE: The design adapted from: -# https://github.com/vllm-project/compressed-tensors/pull/491 - - -import contextlib -import inspect -from functools import partial -from typing import Callable, Optional -from weakref import ref - -import torch -from torch import Tensor -from torch.nn import Module -from torch.utils.hooks import RemovableHandle -from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - -from auto_round.experimental.kv_cache import kvcache_quant_context -from auto_round.experimental.utils import ( - is_attention_module, - per_tensor_fp8_qdq, - update_parameter_data, -) -from auto_round.utils import logger - -__all__ = [ - "QuantizedAttentionImpl", - "initialize_hooked_attention", - "IMPL_ATTR", - "attention_quant_ctx", -] - - -IMPL_ATTR = "impl" -HOOKED_ATTENTION_NAME = "ct_hooked_attention" -QUERY_SCALE_NAME = "q_scale" -QUERY_MAX_NAME = "q_max" - - -class QuantizedAttentionImpl(torch.nn.Module): - """ - QuantizedAttentionImpl module which wraps the functionality of the original - attention implementation. Unlike the original attention function, this - implementation is a `torch.nn.Module` which can be hooked to trigger - transforms and calibration hooks. - - This module works by being registered as a submodule to attention modules via - `initialize_hooked_attention`, registering a new attention implementation function - which calls this module, then setting the model attention implementation to the new - function. After triggering hooks and quantization, this module calls the original - attention implementation function. 
- - :param attn_module: parent attention module - """ - - _original_impl = "sdpa" - - def __init__(self, config: PretrainedConfig, attn_module: Module): - super().__init__() - self.config = config - self.attn_module = ref(attn_module) # avoid circular references - # register query max - device = next(attn_module.parameters()).device - initial_max = torch.tensor([float("-inf")], device=device) - update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME) - initial_scale = torch.tensor([0.0], device=device) - update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME) - - def forward( - self, - module: Module, - query: Tensor, - key: Tensor, - value: Tensor, - *args, - **kwargs, - ): - cur_query_max = query.abs().max() - query_max = torch.max( - getattr(module, QUERY_MAX_NAME).data, - cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device), - ) - update_parameter_data(module, query_max, QUERY_MAX_NAME) - query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max) - update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME) - # original attention - return ALL_ATTENTION_FUNCTIONS[self._original_impl]( - module, - query, - key, - value, - *args, - **kwargs, - ) - - -# ----- initialize ----- # - - -def _ct_hooked_attention(module: Module, *args, **kwargs): - if hasattr(module, IMPL_ATTR): - return module.impl(module, *args, **kwargs) - else: - return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) # pylint: disable=E0601 - - -def initialize_hooked_attention(module: Module, config): - """ - Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances - attached to attention - - :param model: parent model of attention module - :param module: attention module to initialize with - """ - if not hasattr(module, IMPL_ATTR): - module.register_module(IMPL_ATTR, QuantizedAttentionImpl(config, module)) - if config._attn_implementation != HOOKED_ATTENTION_NAME: - # assumes only one model at a time - global _original_impl - _original_impl = config._attn_implementation - # Add new implementation to AttentionInterface(mapping) - AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) - config._attn_implementation = HOOKED_ATTENTION_NAME - - # initialize_hooked_kv_cache(model, module) - - -def prep_attention_module_for_calibration(module: torch.nn.Module, config): - if is_attention_module(module): - logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration") - initialize_hooked_attention(module, config) - - -# # ----- hooks ----- # - - -# def register_query_hook(module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]) -> RemovableHandle: -# """ -# Register a hook which takes post-rope query states as an argument and -# returns the modified query states or `None` - -# :param module: attention module to add hook to -# :param hook: query hook function -# """ -# impl = getattr(module, IMPL_ATTR) - -# def _hook(impl: QuantizedAttentionImpl, args, kwargs): -# bound = inspect.signature(impl.forward).bind(*args, **kwargs) -# value = hook(module, bound.arguments["query"]) -# if value is not None: -# bound.arguments["query"] = value - -# return bound.args, bound.kwargs - -# return impl.register_forward_pre_hook(_hook, with_kwargs=True) - - -def clean_up_hooked_attention(module, model): - if is_attention_module(module): - # Cleanup phase: Restore the original attention implementation - if hasattr(model.config, "_attn_implementation") and hasattr(model, "_original_impl"): - 
model.config._attn_implementation = model._original_impl - del model._original_impl - - -@contextlib.contextmanager -def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn): - try: - # Setup phase: Initialize hooked attention - prepare_fn = partial(prep_attention_module_for_calibration, config=model.config) - model.apply(prepare_fn) - with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype): - yield model - finally: - clean_fn = partial(clean_up_hooked_attention, model=model) - model.apply(clean_fn) diff --git a/auto_round/experimental/kv_cache.py b/auto_round/experimental/kv_cache.py index 56ddb04dd..8a49f3072 100644 --- a/auto_round/experimental/kv_cache.py +++ b/auto_round/experimental/kv_cache.py @@ -24,12 +24,6 @@ import torch from transformers.cache_utils import DynamicCache -from auto_round.experimental.utils import ( - is_attention_module, - normalize_static_kv_dtype, - per_tensor_fp8_qdq, - update_parameter_data, -) from auto_round.utils import logger __all__ = [ @@ -87,6 +81,13 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: return lst +def fp8_per_tensor_qdq(tensor): + from auto_round.data_type.fp8 import quant_fp8_sym + + qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0) + return qdq_tensor, scale + + class QuantizedKVParameterCache(DynamicCache): """ Quantized KV cache used in the forward call based on HF's dynamic cache. @@ -172,8 +173,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_ assert kv_type == KVCacheScaleType.VALUE scales = self.v_scales - qdq_tensor, scale = per_tensor_fp8_qdq(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0)) + qdq_tensor, scale = fp8_per_tensor_qdq(tensor) + _pad_and_append_at_idx_(scales, layer_idx, scale) return qdq_tensor @@ -191,9 +192,13 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4 quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype) setattr(module, "kv_cache", quantized_kv_cache) logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") - init_scale = torch.tensor([0.0], device=next(module.parameters()).device) - update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value) - update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value) + + +def is_attention_module(module: torch.nn.Module): + # FIXME: Handle this better. + return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + ) def calibrate_kv_cache_input_hook( @@ -204,6 +209,7 @@ def calibrate_kv_cache_input_hook( kv_cache quantization. Will update the passed in kv_cache to singleton QuantizedKVParameterCache. """ + logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") kv_cache = getattr(module, "kv_cache") # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 @@ -215,14 +221,33 @@ def calibrate_kv_cache_input_hook( return args, kwargs +def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str): + """ + Update the data of a parameter in a module. + If the parameter does not exist, it will be created. 
+ """ + if hasattr(module, name): + param = getattr(module, name) + if isinstance(param, torch.nn.Parameter): + param.data = new_val + else: + module.register_parameter(name, torch.nn.Parameter(new_val)) + else: + logger.warning( + "Parameter %s not found in module %s, creating new parameter." + % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", ""))) + ) + module.register_parameter(name, torch.nn.Parameter(new_val)) + + def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor): """ Hook to update k_scale and v_scale parameters when running kv_cache quantization. """ - # logger.debug( - # "Calibrate kv_cache output hook for %s %s" - # % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) - # ) + logger.debug( + "Calibrate kv_cache output hook for %s %s" + % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) + ) kv_cache = getattr(module, "kv_cache") k_scale = kv_cache.k_scales[module.layer_idx] v_scale = kv_cache.v_scales[module.layer_idx] @@ -236,6 +261,28 @@ def prep_attention_module_for_calibration(module: torch.nn.Module): module.register_forward_hook(calibrate_kv_cache_output_hook) +def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype: + valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"] + valid_torch_dtype = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "fp8": torch.float8_e4m3fn, + "float8_e4m3fn": torch.float8_e4m3fn, + "float32": torch.float32, + "float": torch.float32, # Alias for float32 + } + if static_kv_dtype in valid_dtype_name_lst: + new_dtype = valid_torch_dtype[static_kv_dtype] + elif static_kv_dtype in valid_torch_dtype.values(): + new_dtype = static_kv_dtype + else: + raise ValueError( + f"Invalid static kv dtype: {static_kv_dtype}. " + f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}." + ) + return new_dtype + + @contextlib.contextmanager def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): """Context manager for FP8 KV cache quantization operations.""" diff --git a/auto_round/experimental/qmodules/fp8_static.py b/auto_round/experimental/qmodules/fp8_static.py index ca013fae2..e7c55086d 100644 --- a/auto_round/experimental/qmodules/fp8_static.py +++ b/auto_round/experimental/qmodules/fp8_static.py @@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor): @torch.no_grad() def forward(self, bf16_input: torch.Tensor) -> torch.Tensor: - original_dtype = bf16_input.dtype + qdq_input = self.qdq_input(bf16_input) qdq_weight = self.dequant_weight_online() - out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias) + out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) return out diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py deleted file mode 100644 index a26e384d9..000000000 --- a/auto_round/experimental/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from auto_round.utils import logger - - -def per_tensor_fp8_qdq( - tensor: torch.Tensor, tensor_max: None | torch.Tensor = None -) -> tuple[torch.Tensor, torch.Tensor]: - from auto_round.data_type.fp8 import quant_fp8_sym - - qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=tensor_max, group_size=0, v=0) - return qdq_tensor, scale - - -# @torch.compiler.disable -def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str): - """ - Update the data of a parameter in a module. - If the parameter does not exist, it will be created. - """ - if hasattr(module, name): - param = getattr(module, name) - if isinstance(param, torch.nn.Parameter): - param.data.copy_(new_val) - else: - module.register_parameter(name, torch.nn.Parameter(new_val)) - else: - logger.warning_once( - "Parameter %s not found in module %s, creating new parameter." - % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", ""))) - ) - module.register_parameter(name, torch.nn.Parameter(new_val)) - - -def normalize_static_kv_dtype(static_kv_dtype: str | torch.dtype) -> torch.dtype: - valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"] - valid_torch_dtype = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "float8_e4m3fn": torch.float8_e4m3fn, - "float32": torch.float32, - "float": torch.float32, # Alias for float32 - } - if static_kv_dtype in valid_dtype_name_lst: - new_dtype = valid_torch_dtype[static_kv_dtype] - elif static_kv_dtype in valid_torch_dtype.values(): - new_dtype = static_kv_dtype - else: - raise ValueError( - f"Invalid static kv dtype: {static_kv_dtype}. " - # f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}." - ) - return new_dtype - - -def is_attention_module(module: torch.nn.Module): - # FIXME: Handle this better. 
- return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") - ) diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index a7adc8f5d..ea484316b 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -35,7 +35,7 @@ def __iter__(self): class TestAutoRound(unittest.TestCase): @classmethod def setUpClass(self): - model_name = "facebook/opt-125m" + model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" self.save_dir = "./saved" self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) @@ -272,8 +272,8 @@ def test_static_afp8_export(self, static_kv_dtype): if static_kv_dtype == "fp8": self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1])) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1, 1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").dtype, torch.float32) shutil.rmtree(quantized_model_path, ignore_errors=True) @@ -302,38 +302,6 @@ def test_static_afp8_export(self, static_kv_dtype): self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) shutil.rmtree(quantized_model_path, ignore_errors=True) - def test_static_fp8_attn(self): - import os - - from safetensors import safe_open - - model_name = "facebook/opt-125m" - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) - autoround = AutoRound( - model, - self.tokenizer, - iters=0, - nsamples=2, - seqlen=2, - scheme="FP8_STATIC", - static_attention_dtype="fp8", - ) - quantized_model_path = "./saved" - autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") - f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") - self.assertIn("model.decoder.layers.8.self_attn.k_proj.input_scale", f.keys()) - self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys()) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1])) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) - check_attrs = ["k_scale", "v_scale", "q_scale"] - for attr in check_attrs: - weight_name = f"model.decoder.layers.8.self_attn.{attr}" - self.assertIn(weight_name, f.keys()) - self.assertEqual(f.get_tensor(weight_name).shape, torch.Size([1])) - self.assertEqual(f.get_tensor(weight_name).dtype, torch.float32) - - shutil.rmtree(quantized_model_path, ignore_errors=True) - if __name__ == "__main__": unittest.main()