diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 2e539b07..176b2f52 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -330,7 +330,7 @@ def _process_quantization(
             inv_perm = torch.argsort(perm)
             output = output.index_select(-1, inv_perm)
 
-    else:  # covers channel, token and tensor strategies
+    else:  # covers tensor, channel, token, and attn_head strategies
         if do_quantize:
             output = _quantize(
                 x=x,
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 4b896d37..1324e83e 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,7 +14,7 @@
 
 
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import (
@@ -152,7 +152,7 @@ def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    observed_shape: Tuple[int],
+    observed_shape: Tuple[Union[int, None]],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
@@ -234,6 +234,13 @@ def initialize_qparams(
         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
+    elif strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size, num_attention_heads, seq_len, head_dim)
+        if len(observed_shape) < 3:
+            raise ValueError("Attention quant requires at least 3 observed dimensions")
+
+        expected_shape = (observed_shape[-3], 1, 1)
+
     else:
         assert False, f"Unknown strategy {strategy}"
 
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 1c27dcd3..00631ace 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index b11e3c0c..1e3e089d 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             QuantizationStrategy.TENSOR,
             QuantizationStrategy.GROUP,
             QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.ATTN_HEAD,
         ):
             if (
                 inputs.strategy == QuantizationStrategy.GROUP
diff --git a/tests/mock_observer.py b/tests/mock_observer.py
index 4563061c..9b3d0e72 100644
--- a/tests/mock_observer.py
+++ b/tests/mock_observer.py
@@ -77,6 +77,8 @@ def flatten_for_quantization(
 
 
 def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (num_rows, num_cols)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (1, 1, num_weight_elems)
         return value.reshape((1, 1, -1))
@@ -110,10 +112,15 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (batch_size * seq_len, 1, hidden_dim)
         return value.reshape((-1, 1, value.size(-1)))
@@ -134,14 +141,18 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
-        # (batch_size, seq_len, num_heads, head_dim)
         # (batch_size * seq_len, 1, num_heads * head_dim)
-        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
 
     if args.strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Token quantization cannot be applied to attention")
@@ -155,4 +166,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py
index 45ba602c..36a857d1 100644
--- a/tests/test_quantization/lifecycle/test_static_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py
@@ -287,45 +287,82 @@ class MockAttention(torch.nn.Module):
                 strategy="tensor",
             ),
             torch.tensor([0.0]),
-            torch.tensor([11.0]),
+            torch.tensor([23.0]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 0.0000, 3.0625, 3.0625],
+                            [3.0625, 6.1250, 6.1250, 6.1250],
+                            [9.1875, 9.1875, 9.1875, 12.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.19,
+            0.81,
         ),
         # static token is not supported
         # channel is not supported
         # group is not supported
         # tensor group is not supported
         # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
+                    ]
+                ]
+            ),
+            0.55,
+        ),
     ],
 )
 def test_static_attention_quantization(
     args, exp_min_val, exp_max_val, exp_quant, exp_loss
 ):
     """
-    input = tensor([[[[ 0.,  1.,  2.],
-                      [ 3.,  4.,  5.]],
+    input = tensor([[[[ 0.,  1.,  2.,  3.],
+                      [ 4.,  5.,  6.,  7.],
+                      [ 8.,  9., 10., 11.]],
 
-                     [[ 6.,  7.,  8.],
-                      [ 9., 10., 11.]]]])
+                     [[12., 13., 14., 15.],
+                      [16., 17., 18., 19.],
+                      [20., 21., 22., 23.]]]])
     """
-    # set up activation (and identity weight)
-    batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
+    # set up attention
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
     input = torch.arange(
-        (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
-    ).reshape((batch_size, seq_len, num_heads, head_dim))
+        (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, num_heads, seq_len, head_dim))
     attention = MockAttention()
 
     # initialize quantization parameters
     scheme = QuantizationScheme(targets=[], input_activations=args)
     initialize_qparams(
-        attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
+        attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
     )
    attention.quantization_scheme = scheme
    attention.quantization_status = QuantizationStatus.INITIALIZED
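Reviewer note, not part of the patch: a minimal standalone sketch, assuming only torch and the shapes from the test fixture above, of why the new ATTN_HEAD branches produce one quantization parameter per attention head. flatten_attention_for_quantization rearranges (batch_size, num_heads, seq_len, head_dim) so that reducing over every axis except the head axis reproduces the (num_heads, 1, 1) shape that initialize_qparams derives from observed_shape[-3].

import torch

# Shapes mirror the test fixture above (hypothetical standalone example).
batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
value = torch.arange(
    batch_size * num_heads * seq_len * head_dim, dtype=torch.float32
).reshape(batch_size, num_heads, seq_len, head_dim)

# ATTN_HEAD branch of flatten_attention_for_quantization:
# (batch_size, num_heads, seq_len, head_dim)
#   -> (batch_size * seq_len, num_heads, 1, 1, head_dim)
flat = value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
assert flat.shape == (batch_size * seq_len, num_heads, 1, 1, head_dim)

# Reducing over all axes except the head axis yields per-head ranges,
# matching expected_shape = (observed_shape[-3], 1, 1) in initialize_qparams.
min_vals = flat.amin(dim=(0, -1))  # shape: (num_heads, 1, 1)
max_vals = flat.amax(dim=(0, -1))
assert torch.equal(min_vals, torch.tensor([[[0.0]], [[12.0]]]))
assert torch.equal(max_vals, torch.tensor([[[11.0]], [[23.0]]]))

These are the same per-head min/max values the attn_head parametrization above passes as exp_min_val and exp_max_val; using observed_shape[-3] for the head axis keeps the qparam shape independent of batch and sequence length.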