95 changes: 92 additions & 3 deletions src/sparseml/pytorch/optim/modifier_quantization.py
@@ -51,6 +51,7 @@
add_quant_dequant,
configure_module_default_qconfigs,
configure_module_qat_wrappers,
fix_observer_quant_range,
fuse_module_conv_bn_relus,
get_qat_qconfig,
prepare_embeddings_qat,
@@ -85,6 +86,7 @@ class QuantizationModifier(ScheduledModifier):
| disable_quantization_observer_epoch: 2.0
| freeze_bn_stats_epoch: 3.0
| reduce_range: False
| activation_bits: 8

:param start_epoch: The epoch to start the modifier at
:param submodules: List of submodule names to perform QAT on. Leave None to quantize
@@ -114,10 +116,16 @@ class QuantizationModifier(ScheduledModifier):
transformer-based models such as BERT, where the quantized MatMul outputs
are kept at 32 bits of precision and fake quantizing the outputs harms training
recovery. Default is True
:param activation_bits: Number of bits to use for setting quant min/max values for
activations. Default is None, which will quantize activations to 8 bits.
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
:param exclude_module_types: optional list of module class names
to not propagate quantization configs to. Default is None
:param activation_qconfig_kwargs: Additional kwargs for quantization of
activations.
:param weight_qconfig_kwargs: Additional kwargs for quantization of
weights.
"""

def __init__(
@@ -132,8 +140,11 @@ def __init__(
quantize_embeddings: bool = True,
reduce_range: bool = False,
quantize_linear_activations: bool = True,
activation_bits: Optional[int] = None,
num_calibration_steps: Optional[int] = None,
exclude_module_types: Union[List[str], None] = None,
activation_qconfig_kwargs: Optional[Dict[str, Any]] = None,
weight_qconfig_kwargs: Optional[Dict[str, Any]] = None,
):
if torch_quantization is None or torch_intrinsic is None:
raise RuntimeError(
@@ -158,12 +169,16 @@ def __init__(
self._quantize_embeddings = quantize_embeddings
self._reduce_range = reduce_range
self._quantize_linear_activations = quantize_linear_activations
self._activation_bits = activation_bits
self._exclude_module_types = exclude_module_types

self._modules_to_quantize = None
self._qat_enabled = False
self._quantization_observer_disabled = False
self._bn_stats_frozen = False
self._activation_qconfig_kwargs = activation_qconfig_kwargs
self._weight_qconfig_kwargs = weight_qconfig_kwargs

self._calibration_dataloader = None
self._calibration_function = None
self._num_calibration_steps = num_calibration_steps
@@ -309,6 +324,32 @@ def exclude_module_types(self) -> Union[List[str], None]:
"""
return self._exclude_module_types

@ModifierProp()
def activation_bits(self) -> Optional[int]:
"""
:return: Number of bits to use for setting quant min/max values for
activations. Default is None, which will quantize activations to 8 bits.
"""
return self._activation_bits

@ModifierProp()
def activation_qconfig_kwargs(self) -> Dict[str, Any]:
"""
:return: additional kwargs for quantization of activations
"""
return self._activation_qconfig_kwargs

@ModifierProp()
def weight_qconfig_kwargs(self) -> Dict[str, Any]:
"""
:return: additional kwargs for quantization of weights
"""
return self._weight_qconfig_kwargs

@ModifierProp()
def num_calibration_steps(self) -> Optional[int]:
"""
@@ -457,11 +498,22 @@ def _enable_module_qat(self, module: Module):
self._model_fuse_fn_kwargs["inplace"] = True
fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs)

activation_qconfig_kwargs = self._get_updated_activation_qconfig_kwargs()

# prepare each module / submodule for quantization
qconfig = get_qat_qconfig(reduce_range=self._reduce_range)
qconfig = get_qat_qconfig(
reduce_range=self._reduce_range,
activation_qconfig_kwargs=activation_qconfig_kwargs,
weight_qconfig_kwargs=self.weight_qconfig_kwargs,
)
for name, quant_module in self._modules_to_quantize:
# wrap any modules with wrap_qat set to True as QATWrapper(s)
configure_module_qat_wrappers(quant_module, reduce_range=self._reduce_range)
configure_module_qat_wrappers(
quant_module,
reduce_range=self._reduce_range,
activation_qconfig_kwargs=activation_qconfig_kwargs,
weight_qconfig_kwargs=self.weight_qconfig_kwargs,
)
# set quantization config (asymmetric activations, symmetric weights)
quant_module.qconfig = qconfig
# wrap all conv / linear blocks in with quantization observers
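To make the wiring above concrete, here is a hedged sketch of how such kwargs can be folded into a QAT qconfig. get_qat_qconfig's real implementation is outside this diff, so this is an assumption built only on public torch.quantization APIs:

# Assumption: a qconfig builder that forwards the extra kwargs into FakeQuantize.
import torch
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver, QConfig

def example_qat_qconfig(
    reduce_range=False,
    activation_qconfig_kwargs=None,
    weight_qconfig_kwargs=None,
):
    activation_qconfig_kwargs = activation_qconfig_kwargs or {}
    weight_qconfig_kwargs = weight_qconfig_kwargs or {}
    # asymmetric activations; quant_min / quant_max / dtype overrides come from the kwargs
    activation = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        reduce_range=reduce_range,
        **activation_qconfig_kwargs,
    )
    # symmetric weights
    weight = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        qscheme=torch.per_tensor_symmetric,
        **weight_qconfig_kwargs,
    )
    return QConfig(activation=activation, weight=weight)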
@@ -480,7 +532,15 @@ def _enable_module_qat(self, module: Module):
# set modules with proper qconfigs to QAT mode
torch_quantization.prepare_qat(module, inplace=True)
if self._quantize_embeddings:
prepare_embeddings_qat(module, reduce_range=self._reduce_range)
prepare_embeddings_qat(
module,
reduce_range=self._reduce_range,
activation_qconfig_kwargs=activation_qconfig_kwargs,
weight_qconfig_kwargs=self.weight_qconfig_kwargs,
)

# propagate custom quant min/max range from FakeQuantize to Observer objects
fix_observer_quant_range(module)

self._qat_enabled = True
self._calibrate_if_possible(module)
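A quick, illustrative way to confirm the propagated range after prepare_qat; this helper is not part of the change and only uses public torch.quantization APIs:

# Illustrative check: list the quant ranges picked up by the fake-quant modules,
# e.g. 0..15 when activation_bits=4 was requested.
from torch.quantization import FakeQuantize

def print_fake_quant_ranges(module):
    for name, submodule in module.named_modules():
        if isinstance(submodule, FakeQuantize):
            print(f"{name}: quant_min={submodule.quant_min}, quant_max={submodule.quant_max}")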
@@ -533,6 +593,35 @@ def _calibrate(self, module):
if module_training:
module.train()

def _get_updated_activation_qconfig_kwargs(self):
activation_qconfig_kwargs = (
self.activation_qconfig_kwargs.copy()
if self.activation_qconfig_kwargs
else {}
)

# update qconfig_kwargs for activation_bits
if self.activation_bits and (
activation_qconfig_kwargs.get("quant_min") is not None
or activation_qconfig_kwargs.get("quant_max") is not None
):
raise ValueError(
"Cannot override quant_max and quant_min with activation_bits enabled"
)

if self.activation_bits:
quant_min = 0
quant_max = 2 ** self.activation_bits - 1
dtype = torch.quint8
activation_qconfig_kwargs.update(
dict(
quant_min=quant_min,
quant_max=quant_max,
dtype=dtype,
)
)
return activation_qconfig_kwargs
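A worked example of the range arithmetic implemented above; the values follow directly from the method:

import torch

bits = 4
kwargs = dict(quant_min=0, quant_max=2 ** bits - 1, dtype=torch.quint8)
assert kwargs == {"quant_min": 0, "quant_max": 15, "dtype": torch.quint8}
# With bits = 8 the range becomes 0..255, matching the default 8-bit behavior.
# Passing quant_min or quant_max in activation_qconfig_kwargs together with
# activation_bits raises the ValueError above instead of silently merging them.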

def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
return (
self._disable_quantization_observer_epoch is not None
6 changes: 3 additions & 3 deletions src/sparseml/pytorch/utils/helpers.py
@@ -44,15 +44,15 @@
QATLinear = None
QATConv2d = None

from sparseml.utils import create_dirs, save_numpy


try:
from torch.nn.qat import Conv3d as QATConv3d
except Exception as _err:
quant_conv3d_err = _err
QATConv3d = None

from sparseml.utils import create_dirs, save_numpy


__all__ = [
"default_device",
"device_of",