[DO NOT MERGE] per-token dynamic observer #24

Draft: wants to merge 4 commits into base: main
src/compressed_tensors/quantization/lifecycle/forward.py (17 additions, 13 deletions)

@@ -112,18 +112,22 @@ def _maybe_calibrate_or_quantize(
     }:
         return value

-    device = next(module.parameters()).device
-    scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
-    zero_point = getattr(module, f"{base_name}_zero_point")
-
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
-        # get observer and get new quant params from observation
-        observer = getattr(module, f"{base_name}_observer")
-        updated_scale, updated_zero_point = observer(value)
-
-        # update scale and zero point
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+    observer = getattr(module, f"{base_name}_observer")
+    if observer.DYNAMIC:
+        # dynamic quantization - get scale and zero point directly from observer
+        scale, zero_point = observer(value)
+    else:
+        # static quantization - get previous scale and zero point from layer
+        scale = getattr(module, f"{base_name}_scale")
+        zero_point = getattr(module, f"{base_name}_zero_point")
+
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            # calibration mode - get new quant params from observer
+            updated_scale, updated_zero_point = observer(value)
+
+            # update scale and zero point
+            device = next(module.parameters()).device
+            scale.data = updated_scale.to(device)
+            zero_point.data = updated_zero_point.to(device)

     return fake_quantize(value, scale, zero_point, args)
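For reference, a minimal torch-only sketch of the dispatch this hunk introduces: a dynamic observer produces fresh quantization parameters for every observed tensor, while the static path reuses stored parameters that are only refreshed during calibration. `ToyDynamicObserver`, `quantize_with_observer`, and the symmetric int8 range below are illustrative assumptions, not part of compressed-tensors.

```python
# Minimal sketch (toy observer, symmetric int8 range); not the library implementation.
import torch


class ToyDynamicObserver:
    DYNAMIC = True  # mirrors the Observer.DYNAMIC flag used for dispatch

    def __call__(self, value: torch.Tensor):
        # compute scale and zero point on the fly from the observed tensor
        scale = value.abs().amax() / 127.0
        zero_point = torch.zeros_like(scale, dtype=torch.int)
        return scale, zero_point


def quantize_with_observer(value, observer, static_scale=None, static_zero_point=None):
    if observer.DYNAMIC:
        # dynamic path: fresh quantization params for every call
        scale, zero_point = observer(value)
    else:
        # static path: reuse previously calibrated params stored on the layer
        scale, zero_point = static_scale, static_zero_point
    # stand-in for fake_quantize(value, scale, zero_point, args)
    quantized = torch.clamp((value / scale).round() + zero_point, -128, 127)
    return (quantized - zero_point) * scale


x = torch.randn(2, 8)
print(quantize_with_observer(x, ToyDynamicObserver()).shape)  # torch.Size([2, 8])
```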
src/compressed_tensors/quantization/lifecycle/frozen.py (5 additions, 2 deletions)

@@ -36,9 +36,12 @@ def freeze_module_quantization(module: Module):

     # delete observers from module
     observer_names = []
-    for submodule_name, _ in module.named_modules():
+    for submodule_name, submodule in module.named_modules():
         if "." not in submodule_name and submodule_name.endswith("_observer"):
-            # delete any observers that belong directly to this module
+            if getattr(submodule, "DYNAMIC", False):
+                continue  # do not delete dynamic observers
+
+            # delete any non-dynamic observers that belong directly to this module
             observer_names.append(submodule_name)
     for observer_name in observer_names:
         delattr(module, observer_name)
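A quick, hedged sketch of the freeze behaviour above using plain torch modules: only observers whose `DYNAMIC` flag is falsy are deleted, so dynamic observers survive freezing. The toy observer classes and the `layer` used here are invented for illustration.

```python
# Illustrative only: toy observers standing in for compressed-tensors observers.
import torch.nn as nn


class ToyStaticObserver(nn.Module):
    DYNAMIC = False


class ToyDynamicObserver(nn.Module):
    DYNAMIC = True


layer = nn.Linear(4, 4)
layer.register_module("weight_observer", ToyStaticObserver())
layer.register_module("input_observer", ToyDynamicObserver())

# mimic the loop above: collect only non-dynamic observers for deletion
observer_names = [
    name
    for name, submodule in layer.named_modules()
    if "." not in name
    and name.endswith("_observer")
    and not getattr(submodule, "DYNAMIC", False)
]
for observer_name in observer_names:
    delattr(layer, observer_name)

print(hasattr(layer, "weight_observer"))  # False: static observer removed
print(hasattr(layer, "input_observer"))   # True: dynamic observer kept
```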
src/compressed_tensors/quantization/lifecycle/initialize.py (7 additions, 4 deletions)

@@ -78,6 +78,13 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
+
+    if observer.DYNAMIC:
+        return  # no need to register a scale and zero point for a dynamic observer
+
     device = next(module.parameters()).device

     # initializes empty scale and zero point parameters for the module
@@ -88,7 +95,3 @@ def _initialize_scale_zero_point_observer(
         torch.empty(0, device=device, dtype=int), requires_grad=False
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
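A hedged sketch of what this initialization change implies for module state: a dynamic observer is attached as a submodule but no `*_scale` / `*_zero_point` parameters are registered, whereas a static observer gets empty placeholder parameters. `toy_initialize` and the toy observer below are illustrative stand-ins, not the library's `_initialize_scale_zero_point_observer`.

```python
# Illustrative stand-in for _initialize_scale_zero_point_observer; not library code.
import torch
import torch.nn as nn


def toy_initialize(module: nn.Module, base_name: str, dynamic: bool) -> None:
    observer = nn.Module()
    observer.DYNAMIC = dynamic
    module.register_module(f"{base_name}_observer", observer)

    if observer.DYNAMIC:
        # dynamic observers compute qparams per forward call, nothing to store
        return

    device = next(module.parameters()).device
    module.register_parameter(
        f"{base_name}_scale",
        nn.Parameter(torch.empty(0, device=device), requires_grad=False),
    )
    module.register_parameter(
        f"{base_name}_zero_point",
        nn.Parameter(torch.empty(0, device=device, dtype=torch.int), requires_grad=False),
    )


layer = nn.Linear(4, 4)
toy_initialize(layer, "input", dynamic=True)
print(hasattr(layer, "input_observer"))  # True: observer submodule attached
print(hasattr(layer, "input_scale"))     # False: no stored qparams for a dynamic observer
```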
src/compressed_tensors/quantization/observers/__init__.py (2 additions)

@@ -19,3 +19,5 @@
 from .base import *
 from .memoryless import *
 from .min_max import *
+from .dynamic import *
+from .per_token import *
src/compressed_tensors/quantization/observers/base.py (3 additions)

@@ -30,6 +30,9 @@ class Observer(Module, RegistryMixin):
     pair
     """

+    # child classes should set this to True if they are meant to be used as dynamic observers
+    DYNAMIC = False
+
     def __init__(self, quantization_args: QuantizationArgs):
         self.quantization_args: QuantizationArgs = quantization_args
         super().__init__()
src/compressed_tensors/quantization/observers/dynamic.py (new file, 35 additions)

@@ -0,0 +1,35 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.memoryless import MemorylessObserver


__all__ = ["DynamicObserver"]


@Observer.register("dynamic")
class DynamicObserver(MemorylessObserver):
"""
Values targted for a dyanmic observer do not require calibration,
this observer will persist in the model through the lifecycle, calculating
the quantization parameters on the fly for each observed Tensor.

This base dynamic observer uses the `calculate_qparams` from MemorylessObserver
where each scale and zero point is based solely on the currently observed
Tensor.
"""

    DYNAMIC = True
src/compressed_tensors/quantization/observers/per_token.py (new file, 70 additions)

@@ -0,0 +1,70 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.helpers import calculate_qparams
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["PerTokenObserver"]


@Observer.register("per_token", alias="per_token_dynamic")
[Contributor comment] Are per token observers always dynamic?

class PerTokenObserver(Observer):
"""
Values targted for a dyanmic observer do not require calibration,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spelling

this observer will persist in the model through the lifecycle, calculating
the quantization parameters on the fly for each observed Tensor.

This base dynamic observer uses the `calculate_qparams` from MemorylessObserver
where each scale and zero point is based solely on the currently observed
Tensor.

:param axis: axis that token dimension is expected to be in
"""

    def __init__(self, quantization_args: QuantizationArgs, axis: int = 1):
        super().__init__(quantization_args=quantization_args)

        self.axis = axis

    DYNAMIC = True

    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
        """
        :param observed: observed tensor to calculate quantization parameters for
        :return: tuple of scale and zero point derived from the observed tensor
        """
        # reduce every dimension except token dimension
        reduce_dims = [idx for idx in range(observed.dim()) if idx != self.axis]

[Member Author comment] should not reduce along batch as well

        # return shape will be [1, ..., num_tokens, 1, ...] with same num dims
        min_vals = observed.amin(dim=reduce_dims, keepdim=True)
        max_vals = observed.amax(dim=reduce_dims, keepdim=True)

        # ensure zero is in the range
        min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
        max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

        # returned shape will match the min/max vals shape
        # since keepdim=True, the reduced dims will have their dims set to 1
        # so scales and zero points should broadcast correctly along the
        # token axis
        # TODO: add test for the broadcast mentioned above
        return calculate_qparams(min_vals, max_vals, self.quantization_args)
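Related to the TODO above, a small, hedged check of the broadcast behaviour using plain torch, assuming a `[batch, num_tokens, hidden_dim]` activation with the token axis at position 1 (the observer's default). It reproduces the current reduction, which also collapses the batch dimension, the point flagged in the author's comment.

```python
# Illustrative shapes only; mirrors the reduction in calculate_qparams above.
import torch

observed = torch.randn(2, 5, 8)  # [batch, num_tokens, hidden_dim]
axis = 1  # token dimension

# reduce every dimension except the token dimension, keeping reduced dims as size 1
reduce_dims = [idx for idx in range(observed.dim()) if idx != axis]
min_vals = torch.min(observed.amin(dim=reduce_dims, keepdim=True), torch.zeros(1))
max_vals = torch.max(observed.amax(dim=reduce_dims, keepdim=True), torch.zeros(1))
print(min_vals.shape)  # torch.Size([1, 5, 1]): one value per token position

# a per-token range derived this way broadcasts back over batch and hidden dims
per_token_range = max_vals - min_vals
print((observed / per_token_range).shape)  # torch.Size([2, 5, 8])
```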