Merged
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
@@ -330,7 +330,7 @@ def _process_quantization(
inv_perm = torch.argsort(perm)
output = output.index_select(-1, inv_perm)

else: # covers channel, token and tensor strategies
else: # covers tensor, channel, token, and attn_head strategies
if do_quantize:
output = _quantize(
x=x,
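The forward-pass change is only a comment update: ATTN_HEAD reuses the same quantize/dequantize path as the tensor, channel, and token strategies, relying on the per-head scale broadcasting against the attention tensor. A minimal broadcasting sketch in plain PyTorch (not the library code; the int4 symmetric range and scale formula are assumed for illustration):

import torch

batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
x = torch.arange(batch_size * num_heads * seq_len * head_dim, dtype=torch.float32)
x = x.reshape(batch_size, num_heads, seq_len, head_dim)

# one scale per head, shape (num_heads, 1, 1); int4 symmetric range assumed
scale = x.abs().amax(dim=(0, 2, 3), keepdim=True)[0] / 7
q = torch.clamp(torch.round(x / scale), min=-8, max=7)  # broadcasts over batch, seq, head_dim
x_dq = q * scale

print(scale.shape)  # torch.Size([2, 1, 1])
print(x_dq.shape)   # torch.Size([1, 2, 3, 4])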
11 changes: 9 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,7 +14,7 @@


import logging
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from compressed_tensors.quantization import (
@@ -152,7 +152,7 @@ def initialize_qparams(
module: Module,
base_name: str,
quantization_args: QuantizationArgs,
observed_shape: Tuple[int],
observed_shape: Tuple[Union[int, None]],
observed_dtype: torch.dtype,
force_zero_point: bool = True,
):
@@ -234,6 +234,13 @@ def initialize_qparams(
num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
expected_shape = (num_rows, num_cols)

elif strategy == QuantizationStrategy.ATTN_HEAD:
# (batch_size, num_attention_heads, seq_len, head_dim)
if len(observed_shape) < 3:
raise ValueError("Attention quant requires at least 3 observed dimensions")

expected_shape = (observed_shape[-3], 1, 1)

else:
assert False, f"Unknown strategy {strategy}"

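A small sketch of the shape bookkeeping added above (the helper name is hypothetical, not part of the library): with ATTN_HEAD, the qparams carry one value per attention head and broadcast over seq_len and head_dim; seq_len may be unknown at init time, hence the Optional entries in observed_shape.

from typing import Optional, Tuple

def attn_head_qparam_shape(observed_shape: Tuple[Optional[int], ...]) -> Tuple[int, int, int]:
    # observed shape is (..., num_attention_heads, seq_len, head_dim)
    if len(observed_shape) < 3:
        raise ValueError("Attention quant requires at least 3 observed dimensions")
    return (observed_shape[-3], 1, 1)

print(attn_head_qparam_shape((8, None, 64)))  # (8, 1, 1)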
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
BLOCK = "block"
TOKEN = "token"
TENSOR_GROUP = "tensor_group"
ATTN_HEAD = "attn_head"


class DynamicType(str, Enum):
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
QuantizationStrategy.TENSOR,
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.ATTN_HEAD,
):
if (
inputs.strategy == QuantizationStrategy.GROUP
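With the new enum value and the validator allow-list entry in place, an attention-head scheme can be declared the same way the new test below does. A short illustrative snippet (assuming QuantizationArgs and QuantizationScheme are importable from compressed_tensors.quantization, as the other files in this diff suggest; targets=[] mirrors the test and would normally name attention modules):

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

args = QuantizationArgs(num_bits=4, type="int", symmetric=True, strategy="attn_head")
scheme = QuantizationScheme(targets=[], input_activations=args)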
19 changes: 17 additions & 2 deletions tests/mock_observer.py
@@ -77,6 +77,8 @@ def flatten_for_quantization(


def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (num_rows, num_cols)

if args.strategy == QuantizationStrategy.TENSOR:
# (1, 1, num_weight_elems)
return value.reshape((1, 1, -1))
@@ -110,10 +110,15 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
.unsqueeze(0)
)

if args.strategy == QuantizationStrategy.ATTN_HEAD:
raise ValueError("attention head quantization cannot be applied to weights")

assert False, f"Unknown strategy {args.strategy}"


def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (batch_size, seq_len, hidden_dim)

if args.strategy == QuantizationStrategy.TENSOR:
# (batch_size * seq_len, 1, hidden_dim)
return value.reshape((-1, 1, value.size(-1)))
@@ -134,14 +141,18 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
if args.strategy == QuantizationStrategy.BLOCK:
raise ValueError("Block quantization cannot be applied to activations")

if args.strategy == QuantizationStrategy.ATTN_HEAD:
raise ValueError("attention head quantization cannot be applied to linear acts")

assert False, f"Unknown strategy {args.strategy}"


def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
# value.shape = (batch_size, num_heads, seq_len, head_dim)

if args.strategy == QuantizationStrategy.TENSOR:
# (batch_size, seq_len, num_heads, head_dim)
# (batch_size * seq_len, 1, num_heads * head_dim)
return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)

if args.strategy == QuantizationStrategy.TOKEN:
raise ValueError("Token quantization cannot be applied to attention")
@@ -155,4 +166,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
if args.strategy == QuantizationStrategy.BLOCK:
raise ValueError("Block quantization cannot be applied to attention")

if args.strategy == QuantizationStrategy.ATTN_HEAD:
# (batch_size * seq_len, num_heads, 1, 1, head_dim)
return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)

assert False, f"Unknown strategy {args.strategy}"
63 changes: 50 additions & 13 deletions tests/test_quantization/lifecycle/test_static_lifecycle.py
@@ -287,45 +287,82 @@ class MockAttention(torch.nn.Module):
strategy="tensor",
),
torch.tensor([0.0]),
torch.tensor([11.0]),
torch.tensor([23.0]),
torch.tensor(
[
[
[[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
[[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
[
[0.0000, 0.0000, 3.0625, 3.0625],
[3.0625, 6.1250, 6.1250, 6.1250],
[9.1875, 9.1875, 9.1875, 12.2500],
],
[
[12.2500, 12.2500, 15.3125, 15.3125],
[15.3125, 18.3750, 18.3750, 18.3750],
[21.5000, 21.5000, 21.5000, 21.5000],
],
]
]
),
0.19,
0.81,
),
# static token is not supported
# channel is not supported
# group is not supported
# tensor group is not supported
# block is not supported
(
QuantizationArgs(
num_bits=4,
type="int",
symmetric=True,
strategy="attn_head",
),
torch.tensor([[[0.0]], [[12.0]]]),
torch.tensor([[[11.0]], [[23.0]]]),
torch.tensor(
[
[
[
[0.0000, 1.4688, 1.4688, 2.9375],
[4.4062, 4.4062, 5.8750, 7.3438],
[7.3438, 8.8125, 10.2500, 10.2500],
],
[
[12.2500, 12.2500, 15.3125, 15.3125],
[15.3125, 18.3750, 18.3750, 18.3750],
[21.5000, 21.5000, 21.5000, 21.5000],
],
]
]
),
0.55,
),
],
)
def test_static_attention_quantization(
args, exp_min_val, exp_max_val, exp_quant, exp_loss
):
"""
input = tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.]],
input = tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],

[[ 6., 7., 8.],
[ 9., 10., 11.]]]])
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]]])
"""
# set up activation (and identity weight)
batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
# set up attention
batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
input = torch.arange(
(batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
).reshape((batch_size, seq_len, num_heads, head_dim))
(batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
).reshape((batch_size, num_heads, seq_len, head_dim))
attention = MockAttention()

# initialize quantization parameters
scheme = QuantizationScheme(targets=[], input_activations=args)
initialize_qparams(
attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
)
attention.quantization_scheme = scheme
attention.quantization_status = QuantizationStatus.INITIALIZED
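A sanity sketch for the attn_head test case above (not part of the test file): with the arange input from the docstring, head 0 sees values 0..11 and head 1 sees 12..23, so the expected per-head minima/maxima are 0/11 and 12/23, which exp_min_val and exp_max_val store with shape (num_heads, 1, 1).

import torch

value = torch.arange(24, dtype=torch.bfloat16).reshape(1, 2, 3, 4)
print(value.amin(dim=(0, 2, 3)))  # per-head minima: 0. and 12.
print(value.amax(dim=(0, 2, 3)))  # per-head maxima: 11. and 23.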