diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index e73e5461e..2965b7073 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -153,6 +153,13 @@ def __init__(self, *args, **kwargs):
         basic.add_argument(
             "--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
         )
+        basic.add_argument(
+            "--static_kv_dtype", default=None, type=str, help="Data type for static quantization of the KV cache. "
+        )
+
+        basic.add_argument(
+            "--static_attention_dtype", default=None, type=str, help="Data type for static quantization of attention. "
+        )
 
         tuning = self.add_argument_group("Tuning Arguments")
         tuning.add_argument(
@@ -599,6 +606,8 @@ def tune(args):
         layer_config=layer_config,
         model_dtype=args.model_dtype,
         momentum=args.momentum,
+        static_kv_dtype=args.static_kv_dtype,
+        static_attention_dtype=args.static_attention_dtype,
     )
 
     model_name = args.model.rstrip("/")
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 62c2f6a32..7c90ed7ac 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -237,6 +237,7 @@ def __init__(
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         self.momentum = kwargs.pop("momentum", 0.0)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
+        static_attention_dtype = kwargs.pop("static_attention_dtype", None)
         model_dtype = kwargs.pop("model_dtype", None)
         device = kwargs.pop("device", None)
         if envs.AR_USE_MODELSCOPE:
@@ -356,6 +357,11 @@ def __init__(
         if self.static_kv_dtype is not None:
             logger.warning("The static kv is experimental and currently has limited support.")
 
+        # Attention static dtype
+        self.static_attention_dtype = static_attention_dtype
+        if self.static_attention_dtype is not None:
+            logger.warning("The static attention dtype is experimental and currently has limited support.")
+
         self._set_amp_dtype()
         self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
         if self.act_bits <= 8 and self.amp_dtype == torch.float16:
@@ -1004,7 +1010,12 @@ def quantize_and_save(
         kwargs.pop("inplace", None)
 
         # Perform model quantization
-        if self.static_kv_dtype is not None:
+        if self.static_attention_dtype is not None:
+            from auto_round.experimental.attention import attention_quant_ctx
+
+            with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
+                model, _ = self.quantize()
+        elif self.static_kv_dtype is not None:
             from auto_round.experimental.kv_cache import kvcache_quant_context
 
             with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
diff --git a/auto_round/experimental/attention.py b/auto_round/experimental/attention.py
new file mode 100644
index 000000000..2f82bc655
--- /dev/null
+++ b/auto_round/experimental/attention.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# NOTICE: The design adapted from: +# https://github.com/vllm-project/compressed-tensors/pull/491 + + +import contextlib +import inspect +from functools import partial +from typing import Callable, Optional +from weakref import ref + +import torch +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from auto_round.experimental.kv_cache import kvcache_quant_context +from auto_round.experimental.utils import ( + is_attention_module, + per_tensor_fp8_qdq, + update_parameter_data, +) +from auto_round.utils import logger + +__all__ = [ + "QuantizedAttentionImpl", + "initialize_hooked_attention", + "IMPL_ATTR", + "attention_quant_ctx", +] + + +IMPL_ATTR = "impl" +HOOKED_ATTENTION_NAME = "ct_hooked_attention" +QUERY_SCALE_NAME = "q_scale" +QUERY_MAX_NAME = "q_max" + + +class QuantizedAttentionImpl(torch.nn.Module): + """ + QuantizedAttentionImpl module which wraps the functionality of the original + attention implementation. Unlike the original attention function, this + implementation is a `torch.nn.Module` which can be hooked to trigger + transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_attention`, registering a new attention implementation function + which calls this module, then setting the model attention implementation to the new + function. After triggering hooks and quantization, this module calls the original + attention implementation function. + + :param attn_module: parent attention module + """ + + _original_impl = "sdpa" + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular references + # register query max + device = next(attn_module.parameters()).device + initial_max = torch.tensor([float("-inf")], device=device) + update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME) + initial_scale = torch.tensor([0.0], device=device) + update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME) + + def forward( + self, + module: Module, + query: Tensor, + key: Tensor, + value: Tensor, + *args, + **kwargs, + ): + cur_query_max = query.abs().max() + query_max = torch.max( + getattr(module, QUERY_MAX_NAME).data, + cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device), + ) + update_parameter_data(module, query_max, QUERY_MAX_NAME) + query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max) + update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME) + # original attention + return ALL_ATTENTION_FUNCTIONS[self._original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + +# ----- initialize ----- # + + +def _ct_hooked_attention(module: Module, *args, **kwargs): + if hasattr(module, IMPL_ATTR): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) # pylint: disable=E0601 + + +def initialize_hooked_attention(module: Module, config): + """ + Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances + attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(config, module)) + if 
config._attn_implementation != HOOKED_ATTENTION_NAME: + # assumes only one model at a time + global _original_impl + _original_impl = config._attn_implementation + # Add new implementation to AttentionInterface(mapping) + AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) + config._attn_implementation = HOOKED_ATTENTION_NAME + + # initialize_hooked_kv_cache(model, module) + + +def prep_attention_module_for_calibration(module: torch.nn.Module, config): + if is_attention_module(module): + logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration") + initialize_hooked_attention(module, config) + + +# # ----- hooks ----- # + + +# def register_query_hook(module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]) -> RemovableHandle: +# """ +# Register a hook which takes post-rope query states as an argument and +# returns the modified query states or `None` + +# :param module: attention module to add hook to +# :param hook: query hook function +# """ +# impl = getattr(module, IMPL_ATTR) + +# def _hook(impl: QuantizedAttentionImpl, args, kwargs): +# bound = inspect.signature(impl.forward).bind(*args, **kwargs) +# value = hook(module, bound.arguments["query"]) +# if value is not None: +# bound.arguments["query"] = value + +# return bound.args, bound.kwargs + +# return impl.register_forward_pre_hook(_hook, with_kwargs=True) + + +def clean_up_hooked_attention(module, model): + if is_attention_module(module): + # Cleanup phase: Restore the original attention implementation + if hasattr(model.config, "_attn_implementation") and hasattr(model, "_original_impl"): + model.config._attn_implementation = model._original_impl + del model._original_impl + + +@contextlib.contextmanager +def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn): + try: + # Setup phase: Initialize hooked attention + prepare_fn = partial(prep_attention_module_for_calibration, config=model.config) + model.apply(prepare_fn) + with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype): + yield model + finally: + clean_fn = partial(clean_up_hooked_attention, model=model) + model.apply(clean_fn) diff --git a/auto_round/experimental/kv_cache.py b/auto_round/experimental/kv_cache.py index 8a49f3072..56ddb04dd 100644 --- a/auto_round/experimental/kv_cache.py +++ b/auto_round/experimental/kv_cache.py @@ -24,6 +24,12 @@ import torch from transformers.cache_utils import DynamicCache +from auto_round.experimental.utils import ( + is_attention_module, + normalize_static_kv_dtype, + per_tensor_fp8_qdq, + update_parameter_data, +) from auto_round.utils import logger __all__ = [ @@ -81,13 +87,6 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list: return lst -def fp8_per_tensor_qdq(tensor): - from auto_round.data_type.fp8 import quant_fp8_sym - - qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0) - return qdq_tensor, scale - - class QuantizedKVParameterCache(DynamicCache): """ Quantized KV cache used in the forward call based on HF's dynamic cache. 
@@ -173,8 +172,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_ assert kv_type == KVCacheScaleType.VALUE scales = self.v_scales - qdq_tensor, scale = fp8_per_tensor_qdq(tensor) - _pad_and_append_at_idx_(scales, layer_idx, scale) + qdq_tensor, scale = per_tensor_fp8_qdq(tensor) + _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0)) return qdq_tensor @@ -192,13 +191,9 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4 quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype) setattr(module, "kv_cache", quantized_kv_cache) logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") - - -def is_attention_module(module: torch.nn.Module): - # FIXME: Handle this better. - return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") - ) + init_scale = torch.tensor([0.0], device=next(module.parameters()).device) + update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value) + update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value) def calibrate_kv_cache_input_hook( @@ -209,7 +204,6 @@ def calibrate_kv_cache_input_hook( kv_cache quantization. Will update the passed in kv_cache to singleton QuantizedKVParameterCache. """ - logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") kv_cache = getattr(module, "kv_cache") # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 @@ -221,33 +215,14 @@ def calibrate_kv_cache_input_hook( return args, kwargs -def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str): - """ - Update the data of a parameter in a module. - If the parameter does not exist, it will be created. - """ - if hasattr(module, name): - param = getattr(module, name) - if isinstance(param, torch.nn.Parameter): - param.data = new_val - else: - module.register_parameter(name, torch.nn.Parameter(new_val)) - else: - logger.warning( - "Parameter %s not found in module %s, creating new parameter." - % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", ""))) - ) - module.register_parameter(name, torch.nn.Parameter(new_val)) - - def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor): """ Hook to update k_scale and v_scale parameters when running kv_cache quantization. 
""" - logger.debug( - "Calibrate kv_cache output hook for %s %s" - % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) - ) + # logger.debug( + # "Calibrate kv_cache output hook for %s %s" + # % (module.__class__.__name__, str(getattr(module, "layer_idx", None))) + # ) kv_cache = getattr(module, "kv_cache") k_scale = kv_cache.k_scales[module.layer_idx] v_scale = kv_cache.v_scales[module.layer_idx] @@ -261,28 +236,6 @@ def prep_attention_module_for_calibration(module: torch.nn.Module): module.register_forward_hook(calibrate_kv_cache_output_hook) -def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype: - valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"] - valid_torch_dtype = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "fp8": torch.float8_e4m3fn, - "float8_e4m3fn": torch.float8_e4m3fn, - "float32": torch.float32, - "float": torch.float32, # Alias for float32 - } - if static_kv_dtype in valid_dtype_name_lst: - new_dtype = valid_torch_dtype[static_kv_dtype] - elif static_kv_dtype in valid_torch_dtype.values(): - new_dtype = static_kv_dtype - else: - raise ValueError( - f"Invalid static kv dtype: {static_kv_dtype}. " - f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}." - ) - return new_dtype - - @contextlib.contextmanager def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): """Context manager for FP8 KV cache quantization operations.""" diff --git a/auto_round/experimental/qmodules/fp8_static.py b/auto_round/experimental/qmodules/fp8_static.py index e7c55086d..ca013fae2 100644 --- a/auto_round/experimental/qmodules/fp8_static.py +++ b/auto_round/experimental/qmodules/fp8_static.py @@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor): @torch.no_grad() def forward(self, bf16_input: torch.Tensor) -> torch.Tensor: - + original_dtype = bf16_input.dtype qdq_input = self.qdq_input(bf16_input) qdq_weight = self.dequant_weight_online() - out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) + out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias) return out diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py new file mode 100644 index 000000000..a26e384d9 --- /dev/null +++ b/auto_round/experimental/utils.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+
+from auto_round.utils import logger
+
+
+def per_tensor_fp8_qdq(
+    tensor: torch.Tensor, tensor_max: None | torch.Tensor = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from auto_round.data_type.fp8 import quant_fp8_sym
+
+    qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=tensor_max, group_size=0, v=0)
+    return qdq_tensor, scale
+
+
+# @torch.compiler.disable
+def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
+    """
+    Update the data of a parameter in a module.
+    If the parameter does not exist, it will be created.
+    """
+    if hasattr(module, name):
+        param = getattr(module, name)
+        if isinstance(param, torch.nn.Parameter):
+            param.data.copy_(new_val)
+        else:
+            module.register_parameter(name, torch.nn.Parameter(new_val))
+    else:
+        logger.warning_once(
+            "Parameter %s not found in module %s, creating new parameter."
+            % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
+        )
+        module.register_parameter(name, torch.nn.Parameter(new_val))
+
+
+def normalize_static_kv_dtype(static_kv_dtype: str | torch.dtype) -> torch.dtype:
+    valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
+    valid_torch_dtype = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "fp8": torch.float8_e4m3fn,
+        "float8_e4m3fn": torch.float8_e4m3fn,
+        "float32": torch.float32,
+        "float": torch.float32,  # Alias for float32
+    }
+    if static_kv_dtype in valid_dtype_name_lst:
+        new_dtype = valid_torch_dtype[static_kv_dtype]
+    elif static_kv_dtype in valid_torch_dtype.values():
+        new_dtype = static_kv_dtype
+    else:
+        raise ValueError(
+            f"Invalid static kv dtype: {static_kv_dtype}. "
+            f"Valid options are: {', '.join(valid_dtype_name_lst)}."
+        )
+    return new_dtype
+
+
+def is_attention_module(module: torch.nn.Module):
+    # FIXME: Handle this better.
+ return "attention" in module.__class__.__name__.lower() and ( + hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + ) diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py index ea484316b..a7adc8f5d 100644 --- a/test/test_cpu/test_export.py +++ b/test/test_cpu/test_export.py @@ -35,7 +35,7 @@ def __iter__(self): class TestAutoRound(unittest.TestCase): @classmethod def setUpClass(self): - model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" + model_name = "facebook/opt-125m" self.save_dir = "./saved" self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) @@ -272,8 +272,8 @@ def test_static_afp8_export(self, static_kv_dtype): if static_kv_dtype == "fp8": self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys()) self.assertIn("model.decoder.layers.8.self_attn.v_scale", f.keys()) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1, 1])) - self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1, 1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_scale").shape, torch.Size([1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").shape, torch.Size([1])) self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.k_scale").dtype, torch.float32) shutil.rmtree(quantized_model_path, ignore_errors=True) @@ -302,6 +302,38 @@ def test_static_afp8_export(self, static_kv_dtype): self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) shutil.rmtree(quantized_model_path, ignore_errors=True) + def test_static_fp8_attn(self): + import os + + from safetensors import safe_open + + model_name = "facebook/opt-125m" + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True) + autoround = AutoRound( + model, + self.tokenizer, + iters=0, + nsamples=2, + seqlen=2, + scheme="FP8_STATIC", + static_attention_dtype="fp8", + ) + quantized_model_path = "./saved" + autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") + f = safe_open(os.path.join(quantized_model_path, "model.safetensors"), framework="pt") + self.assertIn("model.decoder.layers.8.self_attn.k_proj.input_scale", f.keys()) + self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys()) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1])) + self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn) + check_attrs = ["k_scale", "v_scale", "q_scale"] + for attr in check_attrs: + weight_name = f"model.decoder.layers.8.self_attn.{attr}" + self.assertIn(weight_name, f.keys()) + self.assertEqual(f.get_tensor(weight_name).shape, torch.Size([1])) + self.assertEqual(f.get_tensor(weight_name).dtype, torch.float32) + + shutil.rmtree(quantized_model_path, ignore_errors=True) + if __name__ == "__main__": unittest.main()
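
A minimal usage sketch of the new static attention quantization path, modeled on the test_static_fp8_attn case added above. The model name and output directory are placeholders, and `from auto_round import AutoRound` is assumed to be the package's public entry point; the same behavior is reachable from the CLI through the new --static_attention_dtype flag.

# Hedged sketch: calibrate a small model and export it with static FP8 attention
# scales. "facebook/opt-125m" and "./saved" are illustrative placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound  # assumed public entry point

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# scheme="FP8_STATIC" handles weights/activations; static_attention_dtype="fp8"
# additionally wraps each attention forward during calibration so that per-tensor
# q_scale / k_scale / v_scale parameters are collected and exported alongside the
# quantized projections (see the assertions in test_static_fp8_attn).
autoround = AutoRound(
    model,
    tokenizer,
    scheme="FP8_STATIC",
    static_attention_dtype="fp8",
)
autoround.quantize_and_save(output_dir="./saved", format="auto_round")

Compared with --static_kv_dtype alone, this path also records query scales through the hooked attention implementation and then enters the existing KV-cache calibration context, so key and value scales are collected as well.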