Skip to content

Commit

Permalink
move dynamic control to Quant Args
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed Apr 25, 2024
1 parent 5082aad commit 7ca861b
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _maybe_calibrate_or_quantize(
return value

observer = getattr(module, f"{base_name}_observer")
if observer.DYNAMIC:
if args.dynamic:
# dynamic quantization - get scale and zero point directly from observer
scale, zero_point = observer(value)
else:
Expand Down
21 changes: 9 additions & 12 deletions src/compressed_tensors/quantization/lifecycle/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,17 @@ def freeze_module_quantization(module: Module):
:param module: module to freeze quantization for
"""
if not getattr(module, "quantization_scheme", None):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
# no quantization scheme nothing to do
return

# delete observers from module
observer_names = []
for submodule_name, submodule in module.named_modules():
if "." not in submodule_name and submodule_name.endswith("_observer"):
if getattr(submodule, "DYNAMIC", False):
continue # do not delete dynamic observers

# delete any non-dynamic observers that belong directly to this module
observer_names.append(submodule_name)
for observer_name in observer_names:
delattr(module, observer_name)
# delete observers from module if not dynamic
if scheme.input_activations and not scheme.input_activations.dynamic:
delattr(module, "input_observer")
if scheme.weights and not scheme.weights.dynamic:
delattr(module, "weight_observer")
if scheme.output_activations and not scheme.output_activations.dynamic:
delattr(module, "output_observer")

module.quantization_status = QuantizationStatus.FROZEN
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _initialize_scale_zero_point_observer(
observer = quantization_args.get_observer()
module.register_module(f"{base_name}_observer", observer)

if observer.DYNAMIC:
if quantization_args.dynamic:
return # no need to register a scale and zero point for a dynamic observer

device = next(module.parameters()).device
Expand Down
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/observers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@
from .base import *
from .memoryless import *
from .min_max import *
from .dynamic import *
35 changes: 0 additions & 35 deletions src/compressed_tensors/quantization/observers/dynamic.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/observers/memoryless.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
__all__ = ["MemorylessObserver"]


@Observer.register("memoryless")
@Observer.register("memoryless", alias=["dynamic"])
class MemorylessObserver(Observer):
"""
Implements a dynamic quantization observer that sets the scale and
Implements a quantization observer that sets the scale and
zero point based on the latest observed value without tracking state
"""

Expand Down
6 changes: 6 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class QuantizationArgs(BaseModel):
strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
group_size: Optional[int] = None
block_structure: Optional[str] = None
dynamic: bool = False
observer: str = Field(
default="minmax",
description=(
Expand All @@ -82,4 +83,9 @@ def get_observer(self):
"""
from compressed_tensors.quantization.observers.base import Observer

if self.observer == "minmax" and self.dynamic:
# override defualt observer for dynamic, you never want minmax which
# keeps state across samples for dynamic
self.observer = "memoryless"

return Observer.load_from_registry(self.observer, quantization_args=self)

0 comments on commit 7ca861b

Please sign in to comment.