9 changes: 0 additions & 9 deletions auto_round/__main__.py
@@ -153,13 +153,6 @@ def __init__(self, *args, **kwargs):
basic.add_argument(
"--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
)
basic.add_argument(
"--static_kv_dtype", default=None, type=str, help="Data type for static quantize key and value. "
)

basic.add_argument(
"--static_attention_dtype ", default=None, type=str, help="Data type for static quantize attention. "
)

tuning = self.add_argument_group("Tuning Arguments")
tuning.add_argument(
@@ -606,8 +599,6 @@ def tune(args):
layer_config=layer_config,
model_dtype=args.model_dtype,
momentum=args.momentum,
static_kv_dtype=args.static_kv_dtype,
static_attention_dtype=args.static_attention_dtype,
)

model_name = args.model.rstrip("/")
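With `--static_attention_dtype` removed, `--static_kv_dtype` remains the only static-cache option exposed by the CLI. A rough Python-API equivalent of what `tune()` still wires up is sketched below; the `AutoRound` entry point and its keyword forwarding are assumptions inferred from this call site, not shown in the diff.

```python
# Hypothetical sketch: static KV-cache quantization via the Python API.
# The exact AutoRound constructor signature is an assumption here.
from auto_round import AutoRound

autoround = AutoRound(model, tokenizer, bits=4, static_kv_dtype="fp8")
autoround.quantize_and_save(output_dir="./quantized_model")
```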
13 changes: 1 addition & 12 deletions auto_round/compressors/base.py
@@ -237,7 +237,6 @@ def __init__(
enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
self.momentum = kwargs.pop("momentum", 0.0)
static_kv_dtype = kwargs.pop("static_kv_dtype", None)
static_attention_dtype = kwargs.pop("static_attention_dtype", None)
model_dtype = kwargs.pop("model_dtype", None)
device = kwargs.pop("device", None)
if envs.AR_USE_MODELSCOPE:
@@ -357,11 +356,6 @@ def __init__(
if self.static_kv_dtype is not None:
logger.warning("The static kv is experimental and currently has limited support.")

# Attention static dtype
self.static_attention_dtype = static_attention_dtype
if self.static_attention_dtype is not None:
logger.warning("The static attention dtype is experimental and currently has limited support.")

self._set_amp_dtype()
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
if self.act_bits <= 8 and self.amp_dtype == torch.float16:
@@ -1010,12 +1004,7 @@ def quantize_and_save(
kwargs.pop("inplace", None)

# Perform model quantization
if self.static_attention_dtype is not None:
from auto_round.experimental.attention import attention_quant_ctx

with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
model, _ = self.quantize()
elif self.static_kv_dtype is not None:
if self.static_kv_dtype is not None:
from auto_round.experimental.kv_cache import kvcache_quant_context

with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
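After this removal, the dispatch in `quantize_and_save` reduces to roughly the sketch below; the `else` fallback is assumed from the surrounding code, which is truncated in this hunk.

```python
# Sketch of the surviving branch (attention path removed).
if self.static_kv_dtype is not None:
    from auto_round.experimental.kv_cache import kvcache_quant_context

    with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
        model, _ = self.quantize()
else:
    model, _ = self.quantize()
```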
190 changes: 0 additions & 190 deletions auto_round/experimental/attention.py

This file was deleted.

77 changes: 62 additions & 15 deletions auto_round/experimental/kv_cache.py
@@ -24,12 +24,6 @@
import torch
from transformers.cache_utils import DynamicCache

from auto_round.experimental.utils import (
is_attention_module,
normalize_static_kv_dtype,
per_tensor_fp8_qdq,
update_parameter_data,
)
from auto_round.utils import logger

__all__ = [
@@ -87,6 +81,13 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
return lst


def fp8_per_tensor_qdq(tensor):
from auto_round.data_type.fp8 import quant_fp8_sym

qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
return qdq_tensor, scale


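The new local `fp8_per_tensor_qdq` replaces `per_tensor_fp8_qdq` from the deleted utils module. For intuition, a minimal self-contained approximation of a per-tensor symmetric FP8 quant-dequant is shown below; it is not the actual `quant_fp8_sym` kernel, and the scale convention is an assumption.

```python
import torch

def per_tensor_fp8_qdq_sketch(tensor: torch.Tensor):
    # One scale for the whole tensor, symmetric around zero (e4m3 range is roughly ±448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = tensor.abs().max().clamp(min=1e-12) / fp8_max
    q = (tensor / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.to(tensor.dtype) * scale, scale
```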
class QuantizedKVParameterCache(DynamicCache):
"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
@@ -172,8 +173,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
assert kv_type == KVCacheScaleType.VALUE
scales = self.v_scales

qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale)
return qdq_tensor


@@ -191,9 +192,13 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
setattr(module, "kv_cache", quantized_kv_cache)
logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)


def is_attention_module(module: torch.nn.Module):
# FIXME: Handle this better.
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
)


def calibrate_kv_cache_input_hook(
@@ -204,6 +209,7 @@ def calibrate_kv_cache_input_hook(
kv_cache quantization. Will update the passed in
kv_cache to singleton QuantizedKVParameterCache.
"""
logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
kv_cache = getattr(module, "kv_cache")
# Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`.
# https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
@@ -215,14 +221,33 @@ def calibrate_kv_cache_input_hook(
return args, kwargs


def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
"""
Update the data of a parameter in a module.
If the parameter does not exist, it will be created.
"""
if hasattr(module, name):
param = getattr(module, name)
if isinstance(param, torch.nn.Parameter):
param.data = new_val
else:
module.register_parameter(name, torch.nn.Parameter(new_val))
else:
logger.warning(
"Parameter %s not found in module %s, creating new parameter."
% (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
)
module.register_parameter(name, torch.nn.Parameter(new_val))


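`update_parameter_data` is presumably what the output hook below uses to persist calibrated scales onto the attention module. A small standalone check, assuming `KVCacheScaleType.KEY.value` resolves to the `k_scale` parameter name:

```python
import torch

dummy = torch.nn.Module()            # stand-in for an attention module
dummy.layer_idx = 0
update_parameter_data(dummy, torch.tensor([0.021]), KVCacheScaleType.KEY.value)
# A new parameter is registered (with a warning) on first write, then updated in place.
```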
def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
"""
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""
# logger.debug(
# "Calibrate kv_cache output hook for %s %s"
# % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
# )
logger.debug(
"Calibrate kv_cache output hook for %s %s"
% (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
)
kv_cache = getattr(module, "kv_cache")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]
@@ -236,6 +261,28 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
module.register_forward_hook(calibrate_kv_cache_output_hook)


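Only the forward-hook registration is visible in this hunk; a plausible overall shape of the calibration prep is sketched below. The pre-hook registration and `with_kwargs=True` are assumptions, since `calibrate_kv_cache_input_hook` rewrites keyword arguments.

```python
def prep_attention_module_for_calibration_sketch(module: torch.nn.Module):
    # Attach the quantized cache, then observe k/v through pre/post forward hooks.
    initialize_quantized_kv_cache(module, dtype=torch.float8_e4m3fn)
    module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)
    module.register_forward_hook(calibrate_kv_cache_output_hook)
```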
def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
valid_torch_dtype = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"fp8": torch.float8_e4m3fn,
"float8_e4m3fn": torch.float8_e4m3fn,
"float32": torch.float32,
"float": torch.float32, # Alias for float32
}
if static_kv_dtype in valid_dtype_name_lst:
new_dtype = valid_torch_dtype[static_kv_dtype]
elif static_kv_dtype in valid_torch_dtype.values():
new_dtype = static_kv_dtype
else:
raise ValueError(
f"Invalid static kv dtype: {static_kv_dtype}. "
f"Valid options are: {', '.join(valid_dtype_name_lst + [str(v) for v in valid_torch_dtype.values()])}."
)
return new_dtype

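`normalize_static_kv_dtype` accepts either a string alias or a `torch.dtype` and always returns a `torch.dtype`, for example:

```python
normalize_static_kv_dtype("fp8")           # -> torch.float8_e4m3fn
normalize_static_kv_dtype(torch.bfloat16)  # -> torch.bfloat16 (passes through)
normalize_static_kv_dtype("int8")          # -> raises ValueError
```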

@contextlib.contextmanager
def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
"""Context manager for FP8 KV cache quantization operations."""
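The context manager's body is truncated here; a hypothetical calibration loop around it, with `model` and `calib_dataloader` as placeholders, would look roughly like this:

```python
import torch
from auto_round.experimental.kv_cache import kvcache_quant_context

with kvcache_quant_context(model, static_kv_dtype=torch.float8_e4m3fn):
    with torch.no_grad():
        for batch in calib_dataloader:     # placeholder calibration data
            model(**batch)                 # hooks record per-layer k/v scales
# afterwards the attention modules are expected to carry k_scale / v_scale entries
```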
4 changes: 2 additions & 2 deletions auto_round/experimental/qmodules/fp8_static.py
@@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor):

@torch.no_grad()
def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
original_dtype = bf16_input.dtype

qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias)
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
return out
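The explicit dtype casts were dropped, presumably because `qdq_input` and `dequant_weight_online` already return tensors in the activation dtype. For reference, a self-contained sketch of the same static-FP8 quant-dequant linear computation (not the module's actual implementation; the scale handling is an assumption):

```python
import torch

def fp8_static_linear_sketch(x, weight_fp8, weight_scale, input_scale, bias=None):
    # Static input qdq with a pre-computed scale, online weight dequant, then a
    # plain linear in the activation dtype (mirrors the structure above).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_q = (x / input_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    x_qdq = x_q.to(x.dtype) * input_scale
    w_qdq = weight_fp8.to(x.dtype) * weight_scale
    return torch.nn.functional.linear(x_qdq, w_qdq, bias)
```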