From cecabf8a76c133107e7afde80571c8b0d01b9a00 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:56:25 -0400 Subject: [PATCH 01/18] refactor Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 29 ++++++++++--------- .../quantization/utils/helpers.py | 16 ++++++++++ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4b896d37..8c3c4867 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,23 +14,22 @@ import logging +from enum import Enum from typing import Optional, Tuple import torch -from compressed_tensors.quantization import ( +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) +from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, - KVCacheScaleType, QuantizationArgs, - QuantizationMetadata, - QuantizationScheme, - QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) +from compressed_tensors.quantization.quant_config import QuantizationStatus +from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, @@ -54,17 +53,21 @@ _LOGGER = logging.getLogger(__name__) +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - Attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme. + attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme - Previously initialized scales and zero points will be removed from - module if they no longer apply to the scheme + apply to full model with `model.apply(initialize_module_for_quantization)` :param module: module to set for calibration :param scheme: scheme to use for quantization. if None is provided, @@ -77,8 +80,6 @@ def initialize_module_for_quantization( if scheme is None: return - QuantizationMetadata.clear_all_qparams(module) - if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fccd677c..d4428438 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -33,6 +33,7 @@ __all__ = [ + "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -235,6 +236,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max +def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa + """ + Checks the quantization status of a model. Assumes all modules in the model have + the same status, so only the first quantized model is checked. 
+ + :param model: model to check quantization status for + :return: quantization status if the model is quantized, otherwise None + """ + for module in model.modules(): + status = getattr(module, "quantization_status", None) + if status is not None: + return status + return None + + def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization From 7000b74ddcea2f6b453bab225885fca531e498d5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:59:55 -0400 Subject: [PATCH 02/18] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 8c3c4867..4b896d37 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,22 +14,23 @@ import logging -from enum import Enum from typing import Optional, Tuple import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from compressed_tensors.quantization.quant_args import ( +from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, + KVCacheScaleType, QuantizationArgs, + QuantizationMetadata, + QuantizationScheme, + QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, @@ -53,21 +54,17 @@ _LOGGER = logging.getLogger(__name__) -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme + Attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme. - apply to full model with `model.apply(initialize_module_for_quantization)` + Previously initialized scales and zero points will be removed from + module if they no longer apply to the scheme :param module: module to set for calibration :param scheme: scheme to use for quantization. 
if None is provided, @@ -80,6 +77,8 @@ def initialize_module_for_quantization( if scheme is None: return + QuantizationMetadata.clear_all_qparams(module) + if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) From e91bb12e49d1b7c6e58116ac15b024b68b1249dd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 10:01:21 -0400 Subject: [PATCH 03/18] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/utils/helpers.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index d4428438..fccd677c 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -33,7 +33,6 @@ __all__ = [ - "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -236,21 +235,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa - """ - Checks the quantization status of a model. Assumes all modules in the model have - the same status, so only the first quantized model is checked. - - :param model: model to check quantization status for - :return: quantization status if the model is quantized, otherwise None - """ - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization From 50bc670056db5216d728cb14807791d2d881cb8f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:20:01 -0400 Subject: [PATCH 04/18] increase num of required observed dims Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4b896d37..390b174a 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -234,6 +234,12 @@ def initialize_qparams( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) + elif strategy == QuantizationStrategy.ATTN_HEAD: + if len(observed_shape) < 2: + raise ValueError("Attention quant requires at least 2 observed dimensions") + + expected_shape = (observed_shape[-2], 1) + else: assert False, f"Unknown strategy {strategy}" From 42255bc7bf55dd897e375485dd2c3d23ac5165be Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:30:24 -0400 Subject: [PATCH 05/18] remove attention head Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 390b174a..4b896d37 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -234,12 +234,6 @@ def initialize_qparams( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) - elif strategy == 
QuantizationStrategy.ATTN_HEAD: - if len(observed_shape) < 2: - raise ValueError("Attention quant requires at least 2 observed dimensions") - - expected_shape = (observed_shape[-2], 1) - else: assert False, f"Unknown strategy {strategy}" From a0b2caf345bc4c9898262e1597d3057854e880d0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:21:39 -0400 Subject: [PATCH 06/18] add tests Signed-off-by: Kyle Sayers --- tests/observer.py | 216 ++++++++++++++++++ .../lifecycle/test_static_lifecycle.py | 19 ++ 2 files changed, 235 insertions(+) create mode 100644 tests/observer.py diff --git a/tests/observer.py b/tests/observer.py new file mode 100644 index 00000000..b30d19fa --- /dev/null +++ b/tests/observer.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Tuple +from weakref import ref + +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.utils import ( + calculate_qparams, + generate_gparam, + strategy_cdiv, +) +from compressed_tensors.utils import getattr_chain + + +base_name_to_scheme_field = { + "q": "input_activations", + "k": "input_activations", + "v": "input_activations", + "input": "input_activations", + "weight": "weights", + "output": "output_activations", +} + + +class ObserverBase(torch.nn.Module): + def __init__(self, module: torch.nn.Module, base_name: str): + super().__init__() + self.parent = ref(module) + self.base_name = base_name + + self.scheme_field = base_name_to_scheme_field[base_name] + self.args: QuantizationArgs = getattr_chain( + module, f"quantization_scheme.{self.scheme_field}" + ) + + # used for moving averages and testing + self.min_vals = None + self.max_vals = None + + @abstractmethod + def get_min_max(self, observed: torch.Tensor): + ... 
+ + def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + observed = flatten_for_quantization(observed, self.base_name, self.args) + + self.min_vals, self.max_vals = self.get_min_max(observed) + + scales, zero_points = calculate_qparams( + min_vals=self.min_vals, + max_vals=self.max_vals, + quantization_args=self.args, + global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None), + ) + + return scales, zero_points + + def get_global_scale(self, observed: torch.Tensor): + observed = observed.reshape((1, 1, -1)) # per tensor reshape + + min_vals, max_vals = self.get_min_max(observed) + + global_scale = generate_gparam(min_vals, max_vals) + + return global_scale + + +class MockMinMaxObserver(ObserverBase): + def __init__(self, module: torch.nn.Module, base_name: str): + super().__init__(module, base_name) + + def get_min_max(self, observed: torch.Tensor): + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) + + return min_vals, max_vals + + +class MockMovingMinMaxObserver(ObserverBase): + def __init__(self, module: torch.nn.Module, base_name: str): + super().__init__(module, base_name) + + self.averaging_constant = self.args.observer_kwargs.get( + "averaging_constant", 0.01 + ) + + def get_min_max(self, observed: torch.Tensor): + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) + + if self.min_vals is not None: + # FUTURE: consider scaling by num observations (first dim) + # rather than reducing by first dim + min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant) + max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant) + + return min_vals, max_vals + + +def flatten_for_quantization( + value: torch.Tensor, base_name: str, args: QuantizationArgs +) -> torch.Tensor: + if base_name == "weight": + return flatten_weight_for_quantization(value, args) + elif base_name in ("input", "output"): + return flatten_activation_for_quantization(value, args) + elif base_name in ("q", "k", "v"): + return flatten_attention_for_quantization(value, args) + else: + raise ValueError(f"Unknown quantization base name: {base_name}") + + +def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (1, 1, num_weight_elems) + return value.reshape((1, 1, -1)) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to weights") + + if args.strategy == QuantizationStrategy.CHANNEL: + # (1, num_rows, 1, num_cols) + return value.unsqueeze(-2).unsqueeze(0) + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + # (1, num_rows, num_groups, group_size) + return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0) + + if args.strategy == QuantizationStrategy.BLOCK: + # (1, num_block_rows, num_block_cols, block_width * block_height) + block_height, block_width = args.block_structure + num_rows, num_cols = value.shape + num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy) + num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy) + return ( + value.reshape( + num_block_rows, + block_height, + num_block_cols, + block_width, + ) + .transpose(1, 2) + .flatten(-2, -1) + .unsqueeze(0) + ) + + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("attention head quantization cannot be applied to weights") + + assert False, f"Unknown strategy {args.strategy}" + + +def 
flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size * seq_len, 1, hidden_dim) + return value.reshape((-1, 1, value.size(-1))) + + if args.strategy == QuantizationStrategy.TOKEN: + # (batch_size, seq_len, hidden_dim) + # warning: token quantization uses `compute_dynamic_scales_and_zp` + return value.flatten(2, -1) + + if args.strategy == QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to activations") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + # (batch_size * seq_len, num_groups, group_size) + # warning: group activation quantization uses compute_dynamic_scales_and_zp + return value.flatten(0, 1).unflatten(-1, (-1, args.group_size)) + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to activations") + + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("attention head quantization cannot be applied to linear acts") + + assert False, f"Unknown strategy {args.strategy}" + + +def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size, seq_len, num_heads, head_dim) + # (batch_size * seq_len, 1, num_heads * head_dim) + return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to attention") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + raise ValueError("Group quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.ATTN_HEAD: + # (batch_size * seq_len, num_heads, 1, head_dim) + return value.flatten(0, 1).unsqueeze(-2) + + assert False, f"Unknown strategy {args.strategy}" diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 45ba602c..1392a75c 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -303,6 +303,25 @@ class MockAttention(torch.nn.Module): # group is not supported # tensor group is not supported # block is not supported + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="attn_head", + ), + torch.tensor([[0], [3]]), + torch.tensor([[8], [11]]), + torch.tensor( + [ + [ + [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]], + [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]], + ] + ] + ), + 0.16, + ), ], ) def test_static_attention_quantization( From 07170c0eed3f84b62eb147e96abe89c628aeb980 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:27:22 -0400 Subject: [PATCH 07/18] remove attn head Signed-off-by: Kyle Sayers --- tests/observer.py | 10 ---------- .../lifecycle/test_static_lifecycle.py | 19 ------------------- 2 files changed, 29 deletions(-) diff --git a/tests/observer.py b/tests/observer.py index b30d19fa..290153c0 100644 --- a/tests/observer.py +++ b/tests/observer.py @@ -158,9 +158,6 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs) 
.unsqueeze(0) ) - if args.strategy == QuantizationStrategy.ATTN_HEAD: - raise ValueError("attention head quantization cannot be applied to weights") - assert False, f"Unknown strategy {args.strategy}" @@ -185,9 +182,6 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to activations") - if args.strategy == QuantizationStrategy.ATTN_HEAD: - raise ValueError("attention head quantization cannot be applied to linear acts") - assert False, f"Unknown strategy {args.strategy}" @@ -209,8 +203,4 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to attention") - if args.strategy == QuantizationStrategy.ATTN_HEAD: - # (batch_size * seq_len, num_heads, 1, head_dim) - return value.flatten(0, 1).unsqueeze(-2) - assert False, f"Unknown strategy {args.strategy}" diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 1392a75c..45ba602c 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -303,25 +303,6 @@ class MockAttention(torch.nn.Module): # group is not supported # tensor group is not supported # block is not supported - ( - QuantizationArgs( - num_bits=4, - type="int", - symmetric=True, - strategy="attn_head", - ), - torch.tensor([[0], [3]]), - torch.tensor([[8], [11]]), - torch.tensor( - [ - [ - [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]], - [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]], - ] - ] - ), - 0.16, - ), ], ) def test_static_attention_quantization( From 52be0f7a1e66c138d1347617954998346a7f9b11 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 8 Oct 2025 12:12:18 -0400 Subject: [PATCH 08/18] simplify Signed-off-by: Kyle Sayers --- tests/observer.py | 206 ---------------------------------------------- 1 file changed, 206 deletions(-) delete mode 100644 tests/observer.py diff --git a/tests/observer.py b/tests/observer.py deleted file mode 100644 index 290153c0..00000000 --- a/tests/observer.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import abstractmethod -from typing import Tuple -from weakref import ref - -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy -from compressed_tensors.quantization.utils import ( - calculate_qparams, - generate_gparam, - strategy_cdiv, -) -from compressed_tensors.utils import getattr_chain - - -base_name_to_scheme_field = { - "q": "input_activations", - "k": "input_activations", - "v": "input_activations", - "input": "input_activations", - "weight": "weights", - "output": "output_activations", -} - - -class ObserverBase(torch.nn.Module): - def __init__(self, module: torch.nn.Module, base_name: str): - super().__init__() - self.parent = ref(module) - self.base_name = base_name - - self.scheme_field = base_name_to_scheme_field[base_name] - self.args: QuantizationArgs = getattr_chain( - module, f"quantization_scheme.{self.scheme_field}" - ) - - # used for moving averages and testing - self.min_vals = None - self.max_vals = None - - @abstractmethod - def get_min_max(self, observed: torch.Tensor): - ... - - def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - observed = flatten_for_quantization(observed, self.base_name, self.args) - - self.min_vals, self.max_vals = self.get_min_max(observed) - - scales, zero_points = calculate_qparams( - min_vals=self.min_vals, - max_vals=self.max_vals, - quantization_args=self.args, - global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None), - ) - - return scales, zero_points - - def get_global_scale(self, observed: torch.Tensor): - observed = observed.reshape((1, 1, -1)) # per tensor reshape - - min_vals, max_vals = self.get_min_max(observed) - - global_scale = generate_gparam(min_vals, max_vals) - - return global_scale - - -class MockMinMaxObserver(ObserverBase): - def __init__(self, module: torch.nn.Module, base_name: str): - super().__init__(module, base_name) - - def get_min_max(self, observed: torch.Tensor): - min_vals = torch.amin(observed, dim=(0, -1)) - max_vals = torch.amax(observed, dim=(0, -1)) - - return min_vals, max_vals - - -class MockMovingMinMaxObserver(ObserverBase): - def __init__(self, module: torch.nn.Module, base_name: str): - super().__init__(module, base_name) - - self.averaging_constant = self.args.observer_kwargs.get( - "averaging_constant", 0.01 - ) - - def get_min_max(self, observed: torch.Tensor): - min_vals = torch.amin(observed, dim=(0, -1)) - max_vals = torch.amax(observed, dim=(0, -1)) - - if self.min_vals is not None: - # FUTURE: consider scaling by num observations (first dim) - # rather than reducing by first dim - min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant) - max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant) - - return min_vals, max_vals - - -def flatten_for_quantization( - value: torch.Tensor, base_name: str, args: QuantizationArgs -) -> torch.Tensor: - if base_name == "weight": - return flatten_weight_for_quantization(value, args) - elif base_name in ("input", "output"): - return flatten_activation_for_quantization(value, args) - elif base_name in ("q", "k", "v"): - return flatten_attention_for_quantization(value, args) - else: - raise ValueError(f"Unknown quantization base name: {base_name}") - - -def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs): - if args.strategy == QuantizationStrategy.TENSOR: - # (1, 1, num_weight_elems) - return value.reshape((1, 1, -1)) - - if args.strategy == QuantizationStrategy.TOKEN: - raise 
ValueError("Token quantization cannot be applied to weights") - - if args.strategy == QuantizationStrategy.CHANNEL: - # (1, num_rows, 1, num_cols) - return value.unsqueeze(-2).unsqueeze(0) - - if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - # (1, num_rows, num_groups, group_size) - return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0) - - if args.strategy == QuantizationStrategy.BLOCK: - # (1, num_block_rows, num_block_cols, block_width * block_height) - block_height, block_width = args.block_structure - num_rows, num_cols = value.shape - num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy) - num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy) - return ( - value.reshape( - num_block_rows, - block_height, - num_block_cols, - block_width, - ) - .transpose(1, 2) - .flatten(-2, -1) - .unsqueeze(0) - ) - - assert False, f"Unknown strategy {args.strategy}" - - -def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs): - if args.strategy == QuantizationStrategy.TENSOR: - # (batch_size * seq_len, 1, hidden_dim) - return value.reshape((-1, 1, value.size(-1))) - - if args.strategy == QuantizationStrategy.TOKEN: - # (batch_size, seq_len, hidden_dim) - # warning: token quantization uses `compute_dynamic_scales_and_zp` - return value.flatten(2, -1) - - if args.strategy == QuantizationStrategy.CHANNEL: - raise ValueError("Channel quantization cannot be applied to activations") - - if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - # (batch_size * seq_len, num_groups, group_size) - # warning: group activation quantization uses compute_dynamic_scales_and_zp - return value.flatten(0, 1).unflatten(-1, (-1, args.group_size)) - - if args.strategy == QuantizationStrategy.BLOCK: - raise ValueError("Block quantization cannot be applied to activations") - - assert False, f"Unknown strategy {args.strategy}" - - -def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs): - if args.strategy == QuantizationStrategy.TENSOR: - # (batch_size, seq_len, num_heads, head_dim) - # (batch_size * seq_len, 1, num_heads * head_dim) - return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) - - if args.strategy == QuantizationStrategy.TOKEN: - raise ValueError("Token quantization cannot be applied to attention") - - if args.strategy == QuantizationStrategy.CHANNEL: - raise ValueError("Channel quantization cannot be applied to attention") - - if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - raise ValueError("Group quantization cannot be applied to attention") - - if args.strategy == QuantizationStrategy.BLOCK: - raise ValueError("Block quantization cannot be applied to attention") - - assert False, f"Unknown strategy {args.strategy}" From ad09ed8b9007a38ac40a04a28a3614553962c696 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:56:25 -0400 Subject: [PATCH 09/18] refactor Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 29 ++++++++++--------- .../quantization/utils/helpers.py | 16 ++++++++++ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4b896d37..8c3c4867 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,23 +14,22 @@ import logging +from enum 
import Enum from typing import Optional, Tuple import torch -from compressed_tensors.quantization import ( +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) +from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, - KVCacheScaleType, QuantizationArgs, - QuantizationMetadata, - QuantizationScheme, - QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) +from compressed_tensors.quantization.quant_config import QuantizationStatus +from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, @@ -54,17 +53,21 @@ _LOGGER = logging.getLogger(__name__) +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - Attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme. + attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme - Previously initialized scales and zero points will be removed from - module if they no longer apply to the scheme + apply to full model with `model.apply(initialize_module_for_quantization)` :param module: module to set for calibration :param scheme: scheme to use for quantization. if None is provided, @@ -77,8 +80,6 @@ def initialize_module_for_quantization( if scheme is None: return - QuantizationMetadata.clear_all_qparams(module) - if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fccd677c..d4428438 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -33,6 +33,7 @@ __all__ = [ + "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -235,6 +236,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max +def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa + """ + Checks the quantization status of a model. Assumes all modules in the model have + the same status, so only the first quantized model is checked. 
+ + :param model: model to check quantization status for + :return: quantization status if the model is quantized, otherwise None + """ + for module in model.modules(): + status = getattr(module, "quantization_status", None) + if status is not None: + return status + return None + + def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization From 19a505114a775fb66e82afc3256cf92ceded9826 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:59:55 -0400 Subject: [PATCH 10/18] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 8c3c4867..4b896d37 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,22 +14,23 @@ import logging -from enum import Enum from typing import Optional, Tuple import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from compressed_tensors.quantization.quant_args import ( +from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, + KVCacheScaleType, QuantizationArgs, + QuantizationMetadata, + QuantizationScheme, + QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, @@ -53,21 +54,17 @@ _LOGGER = logging.getLogger(__name__) -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme + Attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme. - apply to full model with `model.apply(initialize_module_for_quantization)` + Previously initialized scales and zero points will be removed from + module if they no longer apply to the scheme :param module: module to set for calibration :param scheme: scheme to use for quantization. 
if None is provided, @@ -80,6 +77,8 @@ def initialize_module_for_quantization( if scheme is None: return + QuantizationMetadata.clear_all_qparams(module) + if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) From 370e2ca9a4d534f13dc0681cc12ab5144741e345 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:20:01 -0400 Subject: [PATCH 11/18] increase num of required observed dims Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4b896d37..390b174a 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -234,6 +234,12 @@ def initialize_qparams( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) + elif strategy == QuantizationStrategy.ATTN_HEAD: + if len(observed_shape) < 2: + raise ValueError("Attention quant requires at least 2 observed dimensions") + + expected_shape = (observed_shape[-2], 1) + else: assert False, f"Unknown strategy {strategy}" From bb3ddaf343e6f241941d8eee63a0263a315a07ff Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:03:36 -0400 Subject: [PATCH 12/18] add tests Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 1c27dcd3..00631ace 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum): BLOCK = "block" TOKEN = "token" TENSOR_GROUP = "tensor_group" + ATTN_HEAD = "attn_head" class DynamicType(str, Enum): From 29963037c7fe86f04abb80a021e3b733195e93f2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:18:45 -0400 Subject: [PATCH 13/18] add tests for attn head Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_scheme.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b11e3c0c..1e3e089d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": QuantizationStrategy.TENSOR, QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP, + QuantizationStrategy.ATTN_HEAD, ): if ( inputs.strategy == QuantizationStrategy.GROUP From ed5a2551c910c6126cbcac387ef441893b7090d4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:28:05 -0400 Subject: [PATCH 14/18] add tests Signed-off-by: Kyle Sayers --- tests/mock_observer.py | 10 ++++++++++ .../lifecycle/test_static_lifecycle.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/tests/mock_observer.py b/tests/mock_observer.py index 4563061c..ebed99f5 100644 --- a/tests/mock_observer.py +++ b/tests/mock_observer.py @@ -110,6 +110,9 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs) .unsqueeze(0) ) + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("attention head quantization 
cannot be applied to weights") + assert False, f"Unknown strategy {args.strategy}" @@ -134,6 +137,9 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to activations") + if args.strategy == QuantizationStrategy.ATTN_HEAD: + raise ValueError("attention head quantization cannot be applied to linear acts") + assert False, f"Unknown strategy {args.strategy}" @@ -155,4 +161,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr if args.strategy == QuantizationStrategy.BLOCK: raise ValueError("Block quantization cannot be applied to attention") + if args.strategy == QuantizationStrategy.ATTN_HEAD: + # (batch_size * seq_len, num_heads, 1, head_dim) + return value.flatten(0, 1).unsqueeze(-2) + assert False, f"Unknown strategy {args.strategy}" diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 45ba602c..1392a75c 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -303,6 +303,25 @@ class MockAttention(torch.nn.Module): # group is not supported # tensor group is not supported # block is not supported + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="attn_head", + ), + torch.tensor([[0], [3]]), + torch.tensor([[8], [11]]), + torch.tensor( + [ + [ + [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]], + [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]], + ] + ] + ), + 0.16, + ), ], ) def test_static_attention_quantization( From bfac4753a7dd18c841bb442abe37129be3d1a32e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Oct 2025 18:10:31 -0400 Subject: [PATCH 15/18] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/utils/helpers.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index d4428438..fccd677c 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -33,7 +33,6 @@ __all__ = [ - "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -236,21 +235,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa - """ - Checks the quantization status of a model. Assumes all modules in the model have - the same status, so only the first quantized model is checked. 
- - :param model: model to check quantization status for - :return: quantization status if the model is quantized, otherwise None - """ - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization From bf1b9babcc343e216c522ec761b21d20e2bfc757 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 8 Oct 2025 14:44:32 -0400 Subject: [PATCH 16/18] fix shapes Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/forward.py | 2 +- .../quantization/lifecycle/initialize.py | 13 +++++++------ tests/mock_observer.py | 4 ++-- .../lifecycle/test_static_lifecycle.py | 16 +++++++++------- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 2e539b07..176b2f52 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -330,7 +330,7 @@ def _process_quantization( inv_perm = torch.argsort(perm) output = output.index_select(-1, inv_perm) - else: # covers channel, token and tensor strategies + else: # covers tensor, channel, token, and attn_head strategies if do_quantize: output = _quantize( x=x, diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 390b174a..50757adc 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,7 +14,7 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from compressed_tensors.quantization import ( @@ -152,7 +152,7 @@ def initialize_qparams( module: Module, base_name: str, quantization_args: QuantizationArgs, - observed_shape: Tuple[int], + observed_shape: Tuple[Union[int, None]], observed_dtype: torch.dtype, force_zero_point: bool = True, ): @@ -199,7 +199,7 @@ def initialize_qparams( expected_shape = (1,) elif strategy == QuantizationStrategy.TOKEN: - expected_shape = (1, 1) + raise ValueError("Cannot perform static token quantization") elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 2: @@ -235,10 +235,11 @@ def initialize_qparams( expected_shape = (num_rows, num_cols) elif strategy == QuantizationStrategy.ATTN_HEAD: - if len(observed_shape) < 2: - raise ValueError("Attention quant requires at least 2 observed dimensions") + # (batch_size, num_attention_heads, seq_len, head_dim) + if len(observed_shape) < 3: + raise ValueError("Attention quant requires at least 3 observed dimensions") - expected_shape = (observed_shape[-2], 1) + expected_shape = (observed_shape[-3], 1, 1) else: assert False, f"Unknown strategy {strategy}" diff --git a/tests/mock_observer.py b/tests/mock_observer.py index ebed99f5..6a42bdc1 100644 --- a/tests/mock_observer.py +++ b/tests/mock_observer.py @@ -162,7 +162,7 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr raise ValueError("Block quantization cannot be applied to attention") if args.strategy == QuantizationStrategy.ATTN_HEAD: - # (batch_size * seq_len, num_heads, 1, head_dim) - return value.flatten(0, 1).unsqueeze(-2) + # (batch_size * seq_len, num_heads, 1, 1, head_dim) + return value.transpose(1, 2).flatten(0, 
1).unsqueeze(-2).unsqueeze(-2) assert False, f"Unknown strategy {args.strategy}" diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 1392a75c..73c9a11d 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -310,17 +310,17 @@ class MockAttention(torch.nn.Module): symmetric=True, strategy="attn_head", ), - torch.tensor([[0], [3]]), - torch.tensor([[8], [11]]), + torch.tensor([[[0.0]], [[6.0]]]), + torch.tensor([[[5.0]], [[11.0]]]), torch.tensor( [ [ - [[0.0000, 1.0703, 2.1406], [2.9375, 4.4062, 4.4062]], - [[6.4375, 7.5000, 7.5000], [8.8125, 10.2500, 10.2500]], + [[0.0000, 1.3359, 2.0000], [2.6719, 4.0000, 4.6875]], + [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], ] ] ), - 0.16, + 0.13, ), ], ) @@ -335,7 +335,7 @@ def test_static_attention_quantization( [ 9., 10., 11.]]]]) """ # set up activation (and identity weight) - batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3 + batch_size, num_heads, seq_len, head_dim = 1, 2, 2, 3 input = torch.arange( (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16 ).reshape((batch_size, seq_len, num_heads, head_dim)) @@ -344,7 +344,7 @@ def test_static_attention_quantization( # initialize quantization parameters scheme = QuantizationScheme(targets=[], input_activations=args) initialize_qparams( - attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16 + attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16 ) attention.quantization_scheme = scheme attention.quantization_status = QuantizationStatus.INITIALIZED @@ -366,5 +366,7 @@ def test_static_attention_quantization( assert torch.equal(attention.k_observer.max_vals, exp_max_val) # check forward pass + print(output) + print(torch.nn.functional.mse_loss(output, input)) assert torch.allclose(output, exp_quant.to(output.dtype)) assert torch.nn.functional.mse_loss(output, input) <= exp_loss From e3f24d4dc0a55655d6d4572b889bd1cc0ff33c11 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 8 Oct 2025 15:11:45 -0400 Subject: [PATCH 17/18] fix shapes Signed-off-by: Kyle Sayers --- tests/mock_observer.py | 9 +++- .../lifecycle/test_static_lifecycle.py | 54 ++++++++++++------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/tests/mock_observer.py b/tests/mock_observer.py index 6a42bdc1..9b3d0e72 100644 --- a/tests/mock_observer.py +++ b/tests/mock_observer.py @@ -77,6 +77,8 @@ def flatten_for_quantization( def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs): + # value.shape = (num_rows, num_cols) + if args.strategy == QuantizationStrategy.TENSOR: # (1, 1, num_weight_elems) return value.reshape((1, 1, -1)) @@ -117,6 +119,8 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs) def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs): + # value.shape = (batch_size, seq_len, hidden_dim) + if args.strategy == QuantizationStrategy.TENSOR: # (batch_size * seq_len, 1, hidden_dim) return value.reshape((-1, 1, value.size(-1))) @@ -144,10 +148,11 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs): + # value.shape = (batch_size, num_heads, seq_len, head_dim) + if args.strategy == QuantizationStrategy.TENSOR: - # (batch_size, seq_len, num_heads, head_dim) # (batch_size * 
seq_len, 1, num_heads * head_dim) - return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) + return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2) if args.strategy == QuantizationStrategy.TOKEN: raise ValueError("Token quantization cannot be applied to attention") diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 73c9a11d..36a857d1 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -287,16 +287,24 @@ class MockAttention(torch.nn.Module): strategy="tensor", ), torch.tensor([0.0]), - torch.tensor([11.0]), + torch.tensor([23.0]), torch.tensor( [ [ - [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]], - [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], + [ + [0.0000, 0.0000, 3.0625, 3.0625], + [3.0625, 6.1250, 6.1250, 6.1250], + [9.1875, 9.1875, 9.1875, 12.2500], + ], + [ + [12.2500, 12.2500, 15.3125, 15.3125], + [15.3125, 18.3750, 18.3750, 18.3750], + [21.5000, 21.5000, 21.5000, 21.5000], + ], ] ] ), - 0.19, + 0.81, ), # static token is not supported # channel is not supported @@ -310,17 +318,25 @@ class MockAttention(torch.nn.Module): symmetric=True, strategy="attn_head", ), - torch.tensor([[[0.0]], [[6.0]]]), - torch.tensor([[[5.0]], [[11.0]]]), + torch.tensor([[[0.0]], [[12.0]]]), + torch.tensor([[[11.0]], [[23.0]]]), torch.tensor( [ [ - [[0.0000, 1.3359, 2.0000], [2.6719, 4.0000, 4.6875]], - [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], + [ + [0.0000, 1.4688, 1.4688, 2.9375], + [4.4062, 4.4062, 5.8750, 7.3438], + [7.3438, 8.8125, 10.2500, 10.2500], + ], + [ + [12.2500, 12.2500, 15.3125, 15.3125], + [15.3125, 18.3750, 18.3750, 18.3750], + [21.5000, 21.5000, 21.5000, 21.5000], + ], ] ] ), - 0.13, + 0.55, ), ], ) @@ -328,17 +344,19 @@ def test_static_attention_quantization( args, exp_min_val, exp_max_val, exp_quant, exp_loss ): """ - input = tensor([[[[ 0., 1., 2.], - [ 3., 4., 5.]], + input = tensor([[[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], - [[ 6., 7., 8.], - [ 9., 10., 11.]]]]) + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]]) """ - # set up activation (and identity weight) - batch_size, num_heads, seq_len, head_dim = 1, 2, 2, 3 + # set up attention + batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4 input = torch.arange( - (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16 - ).reshape((batch_size, seq_len, num_heads, head_dim)) + (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16 + ).reshape((batch_size, num_heads, seq_len, head_dim)) attention = MockAttention() # initialize quantization parameters @@ -366,7 +384,5 @@ def test_static_attention_quantization( assert torch.equal(attention.k_observer.max_vals, exp_max_val) # check forward pass - print(output) - print(torch.nn.functional.mse_loss(output, input)) assert torch.allclose(output, exp_quant.to(output.dtype)) assert torch.nn.functional.mse_loss(output, input) <= exp_loss From 85f96cc317cf626e69003eb6f86b21665b74ede0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 9 Oct 2025 10:26:42 -0400 Subject: [PATCH 18/18] revert Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 50757adc..1324e83e 100644 --- 
a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -199,7 +199,7 @@ def initialize_qparams( expected_shape = (1,) elif strategy == QuantizationStrategy.TOKEN: - raise ValueError("Cannot perform static token quantization") + expected_shape = (1, 1) elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 2:
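
A minimal sketch of the per-head flattening and scale shape that this series converges on, assuming toy tensor sizes and a plain min/max reduction in place of the observer classes defined in the patches (illustration only, not part of any patch):

import torch

# toy sizes for illustration only
batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
value = torch.arange(batch_size * num_heads * seq_len * head_dim, dtype=torch.float32)
value = value.reshape(batch_size, num_heads, seq_len, head_dim)

# flatten as in flatten_attention_for_quantization for the ATTN_HEAD strategy:
# (batch_size * seq_len, num_heads, 1, 1, head_dim)
flat = value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)

# reduce over observations (dim 0) and head_dim (last dim), keeping one min/max
# per attention head; the resulting shapes broadcast against scales of shape
# (num_heads, 1, 1), the expected_shape produced by initialize_qparams
min_vals = torch.amin(flat, dim=(0, -1))
max_vals = torch.amax(flat, dim=(0, -1))
print(min_vals.shape, max_vals.shape)  # torch.Size([2, 1, 1]) torch.Size([2, 1, 1])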