This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
86 changes: 0 additions & 86 deletions src/sparseml/onnx/utils/graph_optimizer.py
@@ -40,7 +40,6 @@
__all__ = [
    "fold_conv_bns",
    "quantize_resnet_identity_add_inputs",
    "quantized_residual_add_optim",
]


@@ -202,91 +201,6 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> bool:
    return optimization_made


def quantized_residual_add_optim(quantized_model: onnx.ModelProto) -> bool:
    """
    This optimization adds a quant/dequant block to the identity branch of a
    residual whose non-identity branch is quantized. This enables the add at the
    end of the residual to be fused at runtime.

    The function matches any node that has exactly two child nodes - one Add node
    and one QuantizeLinear node whose branch eventually leads to that same Add node.

    :param quantized_model: A loaded quantized model to perform this optimization on
    :return: True if an in-place optimization was made
    """
    graph = ONNXGraph(quantized_model)
    optimization_made = False
    for node in quantized_model.graph.node:
        children_nodes = graph.get_node_children(node)
        if len(children_nodes) != 2:
            continue

        add_node = [node for node in children_nodes if node.op_type == "Add"]
        quant_node = [
            node for node in children_nodes if node.op_type == "QuantizeLinear"
        ]
        if not add_node or not quant_node:
            continue
        add_node = add_node[0]
        quant_node = quant_node[0]

        # verify that quant_node eventually leads to add_node
        curr_node = [quant_node]
        iter = 0
        max_iter = 20  # avoid cycles
        while curr_node and curr_node[0] != add_node and iter < max_iter:
            curr_node = graph.get_node_children(curr_node[0])
            iter += 1
        if curr_node[0] != add_node:
            continue

        # create de-quantize node for identity
        dequant_node = _make_dequant_node_for_quant(quant_node)

        # update graph
        identity_edge_idx = 0 if add_node.input[0] == node.output[0] else 1
        graph.add_node(dequant_node)
        graph.update_node_input(add_node, dequant_node.output[0], identity_edge_idx)
        optimization_made = True

        # if any of the add node's children are a quantize op while others aren't,
        # add a quant/dequant block to the non-quantized paths to allow for fusion
        # of the add
        add_node_children = graph.get_node_children(add_node)
        add_node_quant_child_idx = [
            idx
            for idx, node in enumerate(add_node_children)
            if node.op_type == "QuantizeLinear"
        ]
        if not add_node_quant_child_idx or all(
            n.op_type == "Add" or n.op_type == "QuantizeLinear"
            for n in add_node_children
        ):
            # no quant child node, or all child nodes are quant/add nodes
            continue

        # make dequant pair node for quant child and add to graph
        add_node_dequant_child = _make_dequant_node_for_quant(
            add_node_children[add_node_quant_child_idx[0]]
        )
        graph.add_node(add_node_dequant_child)

        # update all non quant node children to take the quant/dequant block as input
        for add_child_node in add_node_children:
            if add_child_node.op_type == "QuantizeLinear":
                continue
            add_node_id_idx = [
                idx
                for idx, output_id in enumerate(add_child_node.input)
                if output_id == add_node.output[0]
            ][0]
            graph.update_node_input(
                add_child_node, add_node_dequant_child.output[0], add_node_id_idx
            )

    return optimization_made
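
For context, a minimal sketch of the residual pattern this removed optimization targeted; the names (relu_out, conv_w, scale, zero_point) are illustrative and not from this diff. One child of a node feeds a QuantizeLinear whose branch eventually reaches an Add, while the identity edge from the same node feeds that Add directly; the optimization rewired the identity edge through a new DequantizeLinear so the Add could fuse at runtime.

# Sketch only: the two-child pattern matched above, built with onnx.helper
from onnx import helper

# quantized branch: relu_out -> QuantizeLinear -> DequantizeLinear -> Conv -> conv_out
quant = helper.make_node(
    "QuantizeLinear", ["relu_out", "scale", "zero_point"], ["relu_out_q"]
)
dequant = helper.make_node(
    "DequantizeLinear", ["relu_out_q", "scale", "zero_point"], ["relu_out_dq"]
)
conv = helper.make_node("Conv", ["relu_out_dq", "conv_w"], ["conv_out"])

# identity branch: relu_out feeds the residual Add directly; the optimization
# would rewire this edge through a new DequantizeLinear so the Add can fuse
residual_add = helper.make_node("Add", ["relu_out", "conv_out"], ["block_out"])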


def _make_dequant_node_for_quant(quant_node: onnx.NodeProto) -> onnx.NodeProto:
    return onnx.helper.make_node(
        "DequantizeLinear",
@@ -25,6 +25,7 @@

import numpy
import onnx
import torch
from onnx import ModelProto, NodeProto, numpy_helper

from sparseml.onnx.utils import (
@@ -34,7 +35,6 @@
    get_node_attributes,
    get_node_output_nodes,
    quantize_resnet_identity_add_inputs,
    quantized_residual_add_optim,
    remove_node_and_params_from_graph,
    swap_node_output,
    update_model_param,
@@ -323,9 +323,21 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto):
def _quantize_array(
    array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
) -> numpy.ndarray:
    dmin = numpy.iinfo(dtype).min
    dmax = numpy.iinfo(dtype).max
    return ((array / scale).round() + zero_point).clip(dmin, dmax).astype(dtype)
    if dtype == numpy.uint8:
        tensor_dtype = torch.quint8
    elif dtype == numpy.int8:
        tensor_dtype = torch.qint8
    elif dtype == numpy.int32:
        tensor_dtype = torch.qint32

    tensor = torch.Tensor(array).to(torch.float32)
    if isinstance(scale, numpy.ndarray):
        scale = scale.item()
    if isinstance(zero_point, numpy.ndarray):
        zero_point = zero_point.item()

    quant_tensor = torch.quantize_per_tensor(tensor, scale, zero_point, tensor_dtype)
    return quant_tensor.int_repr().numpy()
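
For reference, a minimal sketch (not part of this diff; values are illustrative) comparing the previous numpy rounding/clipping path with the torch.quantize_per_tensor path now used above. The call sites further down also pass weight_quantize_params.zero_point.dtype (or zero_point.dtype for embeddings) so the dtype-to-torch mapping selects quint8, qint8, or qint32 to match the stored zero point.

# Sketch only: compare the old numpy-based quantization with the new torch path
import numpy
import torch

array = numpy.array([-1.3, -0.2, 0.0, 0.7, 2.5], dtype=numpy.float32)
scale, zero_point, dtype = 0.1, 128, numpy.uint8

# old path: scale, round, shift by the zero point, clip to the dtype's range
dmin, dmax = numpy.iinfo(dtype).min, numpy.iinfo(dtype).max
old = ((array / scale).round() + zero_point).clip(dmin, dmax).astype(dtype)

# new path: delegate rounding and saturation to torch's per-tensor quantization
new = (
    torch.quantize_per_tensor(torch.tensor(array), scale, zero_point, torch.quint8)
    .int_repr()
    .numpy()
)

print(old)  # [115 126 128 135 153]
print(new)  # expected to match for well-behaved inputs like these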


def _convert_quantizable_conv(
@@ -450,6 +462,7 @@ def _convert_quantizable_gemm(
        weight_quantize_params.target,
        weight_quantize_params.scale,
        weight_quantize_params.zero_point,
        weight_quantize_params.zero_point.dtype,
    )
    quantized_weight = quantized_weight.transpose()  # Gemm has implicit transpose
    quantized_weight_name = "{}.weight_quantized".format(gemm_node.name)
@@ -732,6 +745,7 @@ def _add_quantized_conv_matmul_add_ops(
        weight_quantize_params.target,
        weight_quantize_params.scale,
        weight_quantize_params.zero_point,
        weight_quantize_params.zero_point.dtype,
    )
    if transpose_weight:
        quantized_weight = quantized_weight.transpose()
@@ -1404,7 +1418,9 @@ def _quantize_qat_embedding(model: ModelProto):
        embedding = numpy_helper.to_array(embedding_initializer)
        scale = numpy_helper.to_array(scale_initializer)
        zero_point = numpy_helper.to_array(zp_initializer)
        embedding_quant = _quantize_array(embedding, scale, zero_point)
        embedding_quant = _quantize_array(
            embedding, scale, zero_point, zero_point.dtype
        )
        embedding_quant_initializer = numpy_helper.from_array(
            embedding_quant, name=f"{embedding_initializer.name}_quant"
        )
@@ -1569,7 +1585,6 @@ def quantize_torch_qat_export(
    _convert_quantizable_gemm_no_activations(model)
    _quantize_qat_embedding(model)
    quantize_resnet_identity_add_inputs(model)
    quantized_residual_add_optim(model)
    _remove_duplicate_quantize_ops(model)
    _cleanup_unused_quants(model)
2 changes: 1 addition & 1 deletion src/sparseml/version.py
@@ -19,7 +19,7 @@
from datetime import date


version_base = "1.0.0"
version_base = "1.0.1"
is_release = False # change to True to set the generated version as a release version

