9 changes: 9 additions & 0 deletions auto_round/__main__.py
@@ -153,6 +153,13 @@ def __init__(self, *args, **kwargs):
basic.add_argument(
"--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
)
basic.add_argument(
"--static_kv_dtype", default=None, type=str, help="Data type for static quantize key and value. "
)

basic.add_argument(
"--static_attention_dtype ", default=None, type=str, help="Data type for static quantize attention. "
)

tuning = self.add_argument_group("Tuning Arguments")
tuning.add_argument(
@@ -599,6 +606,8 @@ def tune(args):
layer_config=layer_config,
model_dtype=args.model_dtype,
momentum=args.momentum,
static_kv_dtype=args.static_kv_dtype,
static_attention_dtype=args.static_attention_dtype,
)

model_name = args.model.rstrip("/")
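For context, a hedged sketch of the programmatic equivalent of the two new CLI flags. Only the static_kv_dtype and static_attention_dtype keyword arguments come from this PR (they are popped in the compressor's __init__ in base.py below); the AutoRound entry point, model id, and output path are illustrative assumptions:

# Hedged sketch, not part of this diff: passing the new kwargs from Python.
from auto_round import AutoRound

ar = AutoRound(
    "facebook/opt-125m",            # placeholder model id (assumption)
    static_kv_dtype="fp8",          # experimental static FP8 KV-cache quantization
    static_attention_dtype="fp8",   # experimental static FP8 attention (query) quantization
)
ar.quantize_and_save("./tmp_autoround")  # quantize_and_save() appears in base.py below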
13 changes: 12 additions & 1 deletion auto_round/compressors/base.py
@@ -237,6 +237,7 @@ def __init__(
enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
self.momentum = kwargs.pop("momentum", 0.0)
static_kv_dtype = kwargs.pop("static_kv_dtype", None)
static_attention_dtype = kwargs.pop("static_attention_dtype", None)
@n1ck-guo (Contributor) commented on Nov 24, 2025:
Does __main__.py also need to add this parameter?

model_dtype = kwargs.pop("model_dtype", None)
device = kwargs.pop("device", None)
if envs.AR_USE_MODELSCOPE:
@@ -356,6 +357,11 @@ def __init__(
if self.static_kv_dtype is not None:
logger.warning("The static kv is experimental and currently has limited support.")

# Attention static dtype
self.static_attention_dtype = static_attention_dtype
if self.static_attention_dtype is not None:
logger.warning("The static attention dtype is experimental and currently has limited support.")

self._set_amp_dtype()
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
if self.act_bits <= 8 and self.amp_dtype == torch.float16:
@@ -1004,7 +1010,12 @@ def quantize_and_save(
kwargs.pop("inplace", None)

# Perform model quantization
if self.static_kv_dtype is not None:
if self.static_attention_dtype is not None:
from auto_round.experimental.attention import attention_quant_ctx

with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
model, _ = self.quantize()
elif self.static_kv_dtype is not None:
from auto_round.experimental.kv_cache import kvcache_quant_context

with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
190 changes: 190 additions & 0 deletions auto_round/experimental/attention.py
@@ -0,0 +1,190 @@
# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design adapted from:
# https://github.com/vllm-project/compressed-tensors/pull/491


import contextlib
import inspect
from functools import partial
from typing import Callable, Optional
from weakref import ref

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from auto_round.experimental.kv_cache import kvcache_quant_context
from auto_round.experimental.utils import (
is_attention_module,
per_tensor_fp8_qdq,
update_parameter_data,
)
from auto_round.utils import logger

__all__ = [
"QuantizedAttentionImpl",
"initialize_hooked_attention",
"IMPL_ATTR",
"attention_quant_ctx",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"
QUERY_SCALE_NAME = "q_scale"
QUERY_MAX_NAME = "q_max"


class QuantizedAttentionImpl(torch.nn.Module):
"""
QuantizedAttentionImpl module which wraps the functionality of the original
attention implementation. Unlike the original attention function, this
implementation is a `torch.nn.Module` which can be hooked to trigger
transforms and calibration hooks.

This module works by being registered as a submodule to attention modules via
`initialize_hooked_attention`, registering a new attention implementation function
which calls this module, then setting the model attention implementation to the new
function. After triggering hooks and quantization, this module calls the original
attention implementation function.

:param config: config of the parent model
:param attn_module: parent attention module
"""

_original_impl = "sdpa"

def __init__(self, config: PretrainedConfig, attn_module: Module):
super().__init__()
self.config = config
self.attn_module = ref(attn_module) # avoid circular references
# register query max
device = next(attn_module.parameters()).device
initial_max = torch.tensor([float("-inf")], device=device)
update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME)
initial_scale = torch.tensor([0.0], device=device)
update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME)

def forward(
self,
module: Module,
query: Tensor,
key: Tensor,
value: Tensor,
*args,
**kwargs,
):
cur_query_max = query.abs().max()
query_max = torch.max(
getattr(module, QUERY_MAX_NAME).data,
cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device),
)
update_parameter_data(module, query_max, QUERY_MAX_NAME)
query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max)
update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME)
# original attention
return ALL_ATTENTION_FUNCTIONS[self._original_impl](
module,
query,
key,
value,
*args,
**kwargs,
)


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
if hasattr(module, IMPL_ATTR):
return module.impl(module, *args, **kwargs)
else:
return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) # pylint: disable=E0601


def initialize_hooked_attention(module: Module, config):
"""
Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
attached to attention

:param model: parent model of attention module
:param module: attention module to initialize with
"""
if not hasattr(module, IMPL_ATTR):
module.register_module(IMPL_ATTR, QuantizedAttentionImpl(config, module))
if config._attn_implementation != HOOKED_ATTENTION_NAME:
# assumes only one model at a time
global _original_impl
_original_impl = config._attn_implementation
# Add new implementation to AttentionInterface(mapping)
AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
config._attn_implementation = HOOKED_ATTENTION_NAME

# initialize_hooked_kv_cache(model, module)


def prep_attention_module_for_calibration(module: torch.nn.Module, config):
if is_attention_module(module):
logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration")
initialize_hooked_attention(module, config)


# # ----- hooks ----- #


# def register_query_hook(module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]) -> RemovableHandle:
# """
# Register a hook which takes post-rope query states as an argument and
# returns the modified query states or `None`

# :param module: attention module to add hook to
# :param hook: query hook function
# """
# impl = getattr(module, IMPL_ATTR)

# def _hook(impl: QuantizedAttentionImpl, args, kwargs):
# bound = inspect.signature(impl.forward).bind(*args, **kwargs)
# value = hook(module, bound.arguments["query"])
# if value is not None:
# bound.arguments["query"] = value

# return bound.args, bound.kwargs

# return impl.register_forward_pre_hook(_hook, with_kwargs=True)


def clean_up_hooked_attention(module, model):
if is_attention_module(module):
# Cleanup phase: restore the original attention implementation recorded in the
# module-level `_original_impl` global by `initialize_hooked_attention`
if hasattr(model.config, "_attn_implementation") and "_original_impl" in globals():
    model.config._attn_implementation = _original_impl


@contextlib.contextmanager
def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn):
try:
# Setup phase: Initialize hooked attention
prepare_fn = partial(prep_attention_module_for_calibration, config=model.config)
model.apply(prepare_fn)
with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype):
yield model
finally:
clean_fn = partial(clean_up_hooked_attention, model=model)
model.apply(clean_fn)
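For reference, quantize_and_save() in base.py (above) enters this context around self.quantize(); a minimal standalone calibration sketch, where the model loading and dummy input are assumptions rather than code from this PR:

# Minimal sketch: calibrate query/KV scales under the hooked attention context.
import torch
from transformers import AutoModelForCausalLM

from auto_round.experimental.attention import attention_quant_ctx

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)

with attention_quant_ctx(model, static_attention_dtype=torch.float8_e4m3fn):
    # Inside the context each attention module carries a QuantizedAttentionImpl
    # submodule plus the quantized KV-cache hooks, so a forward pass updates
    # q_scale/q_max as well as k_scale/v_scale.
    input_ids = torch.randint(0, model.config.vocab_size, (1, 32))
    model(input_ids)
# On exit, config._attn_implementation is restored to the original implementation.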
77 changes: 15 additions & 62 deletions auto_round/experimental/kv_cache.py
@@ -24,6 +24,12 @@
import torch
from transformers.cache_utils import DynamicCache

from auto_round.experimental.utils import (
is_attention_module,
normalize_static_kv_dtype,
per_tensor_fp8_qdq,
update_parameter_data,
)
from auto_round.utils import logger

__all__ = [
@@ -81,13 +87,6 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
return lst


def fp8_per_tensor_qdq(tensor):
from auto_round.data_type.fp8 import quant_fp8_sym

qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
return qdq_tensor, scale


class QuantizedKVParameterCache(DynamicCache):
"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
Expand Down Expand Up @@ -173,8 +172,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
assert kv_type == KVCacheScaleType.VALUE
scales = self.v_scales

qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale)
qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
return qdq_tensor


@@ -192,13 +191,9 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
setattr(module, "kv_cache", quantized_kv_cache)
logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")


def is_attention_module(module: torch.nn.Module):
# FIXME: Handle this better.
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
)
init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)


def calibrate_kv_cache_input_hook(
Expand All @@ -209,7 +204,6 @@ def calibrate_kv_cache_input_hook(
kv_cache quantization. Will update the passed in
kv_cache to singleton QuantizedKVParameterCache.
"""
logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
kv_cache = getattr(module, "kv_cache")
# Starting from transformers 4.55.2, `past_key_value` was renamed to `past_key_values`.
# https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
Expand All @@ -221,33 +215,14 @@ def calibrate_kv_cache_input_hook(
return args, kwargs


def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
"""
Update the data of a parameter in a module.
If the parameter does not exist, it will be created.
"""
if hasattr(module, name):
param = getattr(module, name)
if isinstance(param, torch.nn.Parameter):
param.data = new_val
else:
module.register_parameter(name, torch.nn.Parameter(new_val))
else:
logger.warning(
"Parameter %s not found in module %s, creating new parameter."
% (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
)
module.register_parameter(name, torch.nn.Parameter(new_val))


def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
"""
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""
logger.debug(
"Calibrate kv_cache output hook for %s %s"
% (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
)
# logger.debug(
# "Calibrate kv_cache output hook for %s %s"
# % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
# )
kv_cache = getattr(module, "kv_cache")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]
Expand All @@ -261,28 +236,6 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
module.register_forward_hook(calibrate_kv_cache_output_hook)


def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
valid_torch_dtype = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"fp8": torch.float8_e4m3fn,
"float8_e4m3fn": torch.float8_e4m3fn,
"float32": torch.float32,
"float": torch.float32, # Alias for float32
}
if static_kv_dtype in valid_dtype_name_lst:
new_dtype = valid_torch_dtype[static_kv_dtype]
elif static_kv_dtype in valid_torch_dtype.values():
new_dtype = static_kv_dtype
else:
raise ValueError(
f"Invalid static kv dtype: {static_kv_dtype}. "
f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}."
)
return new_dtype


@contextlib.contextmanager
def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
"""Context manager for FP8 KV cache quantization operations."""
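The per_tensor_fp8_qdq helper imported above replaces the deleted fp8_per_tensor_qdq and now lives in auto_round.experimental.utils, wrapping quant_fp8_sym. As a rough illustration only (not the library's exact implementation, which may differ in clamping and dtype details), a per-tensor FP8 e4m3 quantize-dequantize looks like this:

# Illustrative per-tensor FP8 (e4m3fn) quantize-dequantize sketch.
from typing import Optional

import torch

def fp8_qdq_sketch(tensor: torch.Tensor, tensor_max: Optional[torch.Tensor] = None):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3fn
    amax = tensor.abs().max() if tensor_max is None else tensor_max
    scale = (amax.float() / fp8_max).clamp(min=1e-12)       # per-tensor scale
    q = (tensor / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.to(tensor.dtype) * scale, scale                # qdq tensor and its scale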
4 changes: 2 additions & 2 deletions auto_round/experimental/qmodules/fp8_static.py
@@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor):

@torch.no_grad()
def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:

original_dtype = bf16_input.dtype
qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias)
return out
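This change presumably keeps both operands of the matmul in the activation's original dtype: torch.nn.functional.linear raises when the input and the (dequantized) weight dtypes differ. A small standalone illustration, not code from this module:

# F.linear requires matching dtypes; casting both to the activation dtype avoids the error.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)   # activation
w = torch.randn(4, 8, dtype=torch.float32)    # e.g. a weight dequantized in float32

try:
    F.linear(x, w)                            # RuntimeError: mismatched dtypes
except RuntimeError as err:
    print("dtype mismatch:", err)

out = F.linear(x, w.to(x.dtype))              # both bfloat16 -> works
print(out.dtype)                              # torch.bfloat16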