1 change: 1 addition & 0 deletions setup.py
@@ -70,6 +70,7 @@
"flake8==3.9.2",
"isort==5.8.0",
"m2r2~=0.2.7",
"mistune==0.8.4",
"myst-parser~=0.14.0",
"rinohtype~=0.4.2",
"sphinx~=3.5.0",
27 changes: 27 additions & 0 deletions src/sparseml/pytorch/optim/modifier_quantization.py
@@ -100,6 +100,8 @@ class QuantizationModifier(ScheduledModifier):
transformer based models such as BERT where the quantized MatMul outputs
are kept at 32 bits of precision and fake quantizing the outputs harms training
recovery. Default is True
:param exclude_module_types: optional list of module class names
to not propagate quantization configs to. Default is None
"""

def __init__(
@@ -114,6 +116,7 @@ def __init__(
quantize_embeddings: bool = True,
reduce_range: bool = False,
quantize_linear_activations: bool = True,
exclude_module_types: Union[List[str], None] = None,
):
if torch_quantization is None or torch_intrinsic is None:
raise RuntimeError(
@@ -138,6 +141,7 @@ def __init__(
self._quantize_embeddings = quantize_embeddings
self._reduce_range = reduce_range
self._quantize_linear_activations = quantize_linear_activations
self._exclude_module_types = exclude_module_types

self._modules_to_quantize = None
self._qat_enabled = False
@@ -278,6 +282,14 @@ def quantize_linear_activations(self) -> bool:
"""
return self._quantize_linear_activations

@ModifierProp()
def exclude_module_types(self) -> Union[List[str], None]:
"""
:return: optional list of module class names to not propagate
quantization configs to. Default is None
"""
return self._exclude_module_types

def initialize(
self,
module: Module,
@@ -423,10 +435,15 @@ def _enable_module_qat(self, module: Module):
if not self._quantize_linear_activations:
remove_activation_qat_by_layer_name(quant_module, ["Linear"])

# remove qconfigs for module types in exclude_module_types
if self._exclude_module_types:
self._strip_excluded_module_qconfigs(module)

# set modules with proper qconfigs to QAT mode
torch_quantization.prepare_qat(module, inplace=True)
if self._quantize_embeddings:
prepare_embeddings_qat(module, reduce_range=self._reduce_range)

self._qat_enabled = True

def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
@@ -443,6 +460,16 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
and not self._bn_stats_frozen
)

def _strip_excluded_module_qconfigs(self, module: Module):
if not self._exclude_module_types:
return
excluded_classes = set(self._exclude_module_types)
for submodule in module.modules():
if submodule.__class__.__name__ in excluded_classes and hasattr(
submodule, "qconfig"
):
submodule.qconfig = None

def _validate_params(self):
if (
self._disable_quantization_observer_epoch is not None
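To show how the new option is meant to be wired up, here is a minimal sketch that loads a recipe using exclude_module_types. The excluded module names mirror the test added later in this PR; the import path is an assumption and not part of this diff.

# Sketch: load a recipe that uses the new exclude_module_types option
# (module names mirror the test below; the import path is assumed)
from sparseml.pytorch.optim import QuantizationModifier

recipe = """
!QuantizationModifier
    start_epoch: 0.0
    exclude_module_types: ['LayerNorm', 'Tanh']
"""

modifier = QuantizationModifier.load_obj(recipe)
assert modifier.exclude_module_types == ["LayerNorm", "Tanh"]

Any module whose class name appears in the list keeps its qconfig set to None before prepare_qat runs, so those submodules are skipped during QAT conversion.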
28 changes: 20 additions & 8 deletions src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -264,7 +264,11 @@ def _delete_repeated_qat_blocks(model: ModelProto):
nodes_to_delete.append(dequant_node_1)

for n in nodes_to_delete:
delete_quant_node(model, n)
delete_quant_node(model, n, keep_params=True)

# cleanup graph
graph.update()
graph.delete_unused_initializers()


def _attribute_to_kwarg(attribute: onnx.AttributeProto):
@@ -1214,12 +1218,14 @@ def _quantize_qat_embedding(model: ModelProto):
qdq_output = False

if qdq_output:
# forward gather output to dequant input
output_dequant_node.input[0] = gather_node.output[0]
output_dequant_node.input[1] = input_quant_node.input[1]
output_dequant_node.input[2] = input_quant_node.input[2]
# delete unnecessary quantize and dequantize ops
delete_quant_node(model, input_quant_node, keep_params=False)
delete_quant_node(model, input_quant_node, keep_params=True)
delete_quant_node(model, input_dequant_node, keep_params=False)
delete_quant_node(model, output_quant_node, keep_params=False)
# forward gather output to dequant input
output_dequant_node.input[0] = gather_node.output[0]

else:
# use input dequant to dequantize output
@@ -1265,7 +1271,10 @@ def _remove_duplicate_quantize_ops(model: ModelProto):
_replace_input_id_model(
model, remove_node.output[0], keep_node.output[0]
)
remove_node_and_params_from_graph(model, remove_node)
delete_quant_node(model, remove_node, keep_params=True)
# cleanup graph
graph.update()
graph.delete_unused_initializers()


def _cleanup_unused_quants(model: ModelProto):
@@ -1296,15 +1305,18 @@ def _cleanup_unused_quants(model: ModelProto):
continue

# Forward QuantizeLinear input to DequantizeLinear output
for child in dequant_children:
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])

# Remove QuantizeLinear->DequantizeLinear block
nodes_to_delete.append(quant_node)
nodes_to_delete.append(dequant_node)

for n in nodes_to_delete:
delete_quant_node(model, n)
delete_quant_node(model, n, keep_params=True)

# update graph
graph.update()
graph.delete_unused_initializers()


def quantize_torch_qat_export(
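The edits in this file all follow the same pattern: rewire consumers of a DequantizeLinear output onto the matching QuantizeLinear input, delete the pair while keeping shared scale/zero-point parameters (keep_params=True), and only then prune initializers the graph no longer references. Below is a self-contained sketch of that pattern using plain onnx; remove_qdq_pair is a hypothetical helper written for illustration, not a sparseml API.

# Self-contained sketch of the rewire-and-prune pattern used above
# (plain onnx only; remove_qdq_pair is hypothetical, not a sparseml helper)
import onnx


def remove_qdq_pair(model: onnx.ModelProto, quant_name: str, dequant_name: str):
    graph = model.graph
    quant = next(n for n in graph.node if n.name == quant_name)
    dequant = next(n for n in graph.node if n.name == dequant_name)

    # forward the QuantizeLinear input to every consumer of the DequantizeLinear output
    for node in graph.node:
        for idx, input_name in enumerate(node.input):
            if input_name == dequant.output[0]:
                node.input[idx] = quant.input[0]

    # drop the pair; scale/zero-point initializers stay for now, mirroring keep_params=True
    graph.node.remove(quant)
    graph.node.remove(dequant)

    # prune any initializer that no remaining node references
    referenced = {name for node in graph.node for name in node.input}
    for initializer in list(graph.initializer):
        if initializer.name not in referenced:
            graph.initializer.remove(initializer)

Deleting nodes with keep_params=True and deferring the cleanup to a single delete_unused_initializers pass avoids removing a scale or zero-point tensor that another quantize/dequantize node still shares.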
21 changes: 21 additions & 0 deletions src/sparseml/transformers/export.py
@@ -55,6 +55,8 @@
"""

import argparse
import collections
import inspect
import logging
import math
import os
@@ -180,13 +182,32 @@ def export_transformer_to_onnx(
inputs = tokenizer(
"", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
).data # Dict[Tensor]

# Rearrange inputs' keys to match those defined by the model forward func, which
# seems to define how the order of inputs is determined in the exported model
forward_args_spec = inspect.getfullargspec(model.__class__.forward)
dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
inputs = collections.OrderedDict(
[
(f, inputs[f][0].reshape(1, -1))
for f in forward_args_spec.args
if f in inputs
]
)
if dropped:
_LOGGER.warning(
"The following inputs were not present in the model forward function "
f"and therefore dropped from ONNX export: {dropped}"
)

inputs_shapes = {
key: (
f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}"
)
for key, val in inputs.items()
}

_LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")

# run export
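The reordering added above can be shown in isolation with a toy forward signature; only standard-library calls are involved, and the signature and tensors below are made up for the example.

# Toy illustration of matching tokenizer outputs to a forward signature
# (the signature and input values here are made up for the example)
import collections
import inspect


def forward(self, input_ids, attention_mask, token_type_ids=None):
    pass


inputs = {"attention_mask": [1, 1], "input_ids": [101, 102], "labels": [0]}

forward_args_spec = inspect.getfullargspec(forward)
dropped = [name for name in inputs if name not in forward_args_spec.args]
inputs = collections.OrderedDict(
    (name, inputs[name]) for name in forward_args_spec.args if name in inputs
)

print(list(inputs))  # ['input_ids', 'attention_mask'] -- signature order, not dict order
print(dropped)       # ['labels'] -- not a forward arg, so it is dropped with a warning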
2 changes: 1 addition & 1 deletion src/sparseml/version.py
@@ -19,7 +19,7 @@
from datetime import date


version_base = "0.10.0"
version_base = "0.10.1"
is_release = False # change to True to set the generated version as a release version


32 changes: 21 additions & 11 deletions tests/sparseml/pytorch/optim/test_modifier_quantization.py
@@ -54,6 +54,10 @@
start_epoch=0.0,
quantize_linear_activations=False,
),
lambda: QuantizationModifier(
start_epoch=0.0,
exclude_module_types=["Linear"],
),
]


@@ -67,9 +71,13 @@ def _is_quantiable_module(module):
return isinstance(module, (Conv2d, Linear))


def _test_quantizable_module(
module, qat_expected, reduce_range, quantize_linear_activations
):
def _test_quantizable_module(module, qat_expected, modifier):
reduce_range = modifier.reduce_range
quantize_linear_activations = modifier.quantize_linear_activations

excluded_types = modifier.exclude_module_types or []
qat_expected = qat_expected and module.__class__.__name__ not in excluded_types

if qat_expected:
assert hasattr(module, "qconfig") and module.qconfig is not None
assert hasattr(module, "weight_fake_quant") and (
@@ -97,12 +105,7 @@ def _test_qat_applied(modifier, model):
submodules = [""]
for module in model.modules():
if _is_quantiable_module(module):
_test_quantizable_module(
module,
True,
modifier.reduce_range,
modifier.quantize_linear_activations,
)
_test_quantizable_module(module, True, modifier)
else:
assert not hasattr(model, "qconfig") or model.qconfig is None
submodules = modifier.submodules
Expand All @@ -112,8 +115,7 @@ def _test_qat_applied(modifier, model):
_test_quantizable_module(
module,
_is_valid_submodule(name, submodules),
modifier.reduce_range,
modifier.quantize_linear_activations,
modifier,
)


@@ -207,6 +209,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings = False
reduce_range = True
quantize_linear_activations = False
exclude_module_types = ["LayerNorm", "Tanh"]
yaml_str = f"""
!QuantizationModifier
start_epoch: {start_epoch}
Expand All @@ -217,6 +220,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings: {quantize_embeddings}
reduce_range: {reduce_range}
quantize_linear_activations: {quantize_linear_activations}
exclude_module_types: {exclude_module_types}
"""
yaml_modifier = QuantizationModifier.load_obj(
yaml_str
Expand All @@ -233,6 +237,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings=quantize_embeddings,
reduce_range=reduce_range,
quantize_linear_activations=quantize_linear_activations,
exclude_module_types=exclude_module_types,
)

assert isinstance(yaml_modifier, QuantizationModifier)
@@ -276,3 +281,8 @@ def test_quantization_modifier_yaml():
== serialized_modifier.quantize_linear_activations
== obj_modifier.quantize_linear_activations
)
assert (
yaml_modifier.exclude_module_types
== serialized_modifier.exclude_module_types
== obj_modifier.exclude_module_types
)