[Feature] Add onnx exporters (#475)
* fix del redundant fakequant

* add onnx exporters

* fix onnx exporters and add docstring

* fix comments

* delete useless codes

* fix export_onnx in native quantizer

---------

Co-authored-by: pppppM <gjf_mail@126.com>
2 people authored and humu789 committed Apr 17, 2023
1 parent 0cd361d commit d6a7ea5
Showing 8 changed files with 735 additions and 43 deletions.
32 changes: 31 additions & 1 deletion mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -13,12 +13,14 @@

try:
    from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver,
                                       PerChannelMinMaxObserver,
                                       disable_observer)
except ImportError:
    from mmrazor.utils import get_placeholder
    FakeQuantizeBase = get_placeholder('torch>=1.13')
    MinMaxObserver = get_placeholder('torch>=1.13')
    PerChannelMinMaxObserver = get_placeholder('torch>=1.13')
    disable_observer = get_placeholder('torch>=1.13')

LossResults = Dict[str, torch.Tensor]
TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
@@ -213,6 +215,34 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]):
        data = self.data_preprocessor(data, False)
        return self._run_forward(data, mode='predict')

    def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)):
        """Prepare the model for deployment to a backend with mmdeploy. It is
        called inside mmdeploy and usually includes the following steps:

        1. prepare the float model rewritten by mmdeploy.
        2. load a checkpoint that consists of float weights and the quantized
           params produced by mmrazor.
        3. post-process the weight fakequant so that the exported .onnx meets
           the backend's requirements.
        """

        quantized_state_dict = self.qmodels['tensor'].state_dict()
        fp32_model = self.architecture
        self.quantizer.convert_batchnorm2d(fp32_model)
        observed_model = self.quantizer.prepare(fp32_model, {'mode': 'tensor'})

        if dummy_input is not None:
            observed_model(torch.randn(dummy_input))

        observed_model.load_state_dict(quantized_state_dict)

        self.quantizer.post_process_for_deploy(
            observed_model, keep_w_fake_quant=True)

        observed_model.apply(disable_observer)

        return observed_model
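For context, a minimal usage sketch (the `algorithm` variable below is an assumption standing in for a trained MMArchitectureQuant instance, and the plain torch.onnx.export call is only a stand-in for mmdeploy's export pipeline; neither is part of this commit):

# Hypothetical sketch: produce a deploy-ready observed model with
# deploy-friendly fakequant and observers disabled, then export it.
observed_model = algorithm.post_process_for_mmdeploy(
    dummy_input=(1, 3, 224, 224))
torch.onnx.export(observed_model, torch.randn(1, 3, 224, 224),
                  'model_fakequant.onnx')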


@MODEL_WRAPPERS.register_module()
class MMArchitectureQuantDDP(MMDistributedDataParallel):
5 changes: 5 additions & 0 deletions mmrazor/models/quantizers/exporters/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .openvino_quantize_exporter import OpenVinoQuantizeExportor
from .tensorrt_quantize_exporter import TensorRTExplicitExporter

__all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter']
164 changes: 164 additions & 0 deletions mmrazor/models/quantizers/exporters/base_quantize_exporter.py
@@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import onnx
from mmengine import print_log
from onnx import numpy_helper

from .optim_utils import ONNXOptimUtils

SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose']

PERCHANNEL_FAKEQUANTIZER = [
    'FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine'
]
PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', 'FixedPerTensorAffine']

ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER


def _parse_attrs(node_attrs):
    attrs = {}
    for attr in node_attrs:
        if attr.type == onnx.AttributeProto.AttributeType.INTS:
            attrs[attr.name] = tuple(attr.ints)
        elif attr.type == onnx.AttributeProto.AttributeType.INT:
            attrs[attr.name] = attr.i
        elif attr.type == onnx.AttributeProto.AttributeType.FLOATS:
            attrs[attr.name] = tuple(attr.floats)
        elif attr.type == onnx.AttributeProto.AttributeType.FLOAT:
            attrs[attr.name] = attr.f
        elif attr.type == onnx.AttributeProto.AttributeType.TENSOR:
            attrs[attr.name] = numpy_helper.to_array(attr.t)
        elif attr.type == onnx.AttributeProto.AttributeType.STRING:
            attrs[attr.name] = str(attr.s)
        elif attr.type == onnx.AttributeProto.AttributeType.STRINGS:
            attrs[attr.name] = tuple([str(x) for x in attr.strings])
        else:
            raise Exception('ATTR Type [{}] Not Supported!'.format(attr.type))
    return attrs
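A quick illustration of this helper (the node and its attribute values are made up; onnx's helper.make_node turns keyword arguments into typed attributes):

from onnx import helper

# Made-up node: integer keyword arguments become INT attributes.
node = helper.make_node('FakeQuantize', inputs=['x'], outputs=['y'],
                        quant_min=-128, quant_max=127)
assert _parse_attrs(node.attribute) == {'quant_min': -128, 'quant_max': 127}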


class BaseQuantizeExportor():

    optimizer = ONNXOptimUtils

    def __init__(self, onnx_model, export_path) -> None:

        if isinstance(onnx_model, str):
            self.onnx_model = onnx.load(onnx_model)
        elif isinstance(onnx_model, onnx.ModelProto):
            self.onnx_model = onnx_model
        else:
            raise TypeError

        self.export_path = export_path
        self._init_mappings_from_onnx(self.onnx_model)

        self.optimizer.remove_fake_pad_op(self.onnx_model, self.name2data,
                                          self.input2node, self.output2node)

        self._remap_input_and_node()
        self._remap_output_and_node()

    @property
    def graph(self):
        """The onnx model's graph."""
        return self.onnx_model.graph

    def _init_mappings_from_onnx(self, onnx_model):
        """Build the necessary mappings from an onnx model."""

        self.input2node = self.optimizer.map_input_and_node(onnx_model)
        self.output2node = self.optimizer.map_output_and_node(onnx_model)
        self.name2data = self.optimizer.map_name_and_data(onnx_model)

        # todo: maybe useless
        # self.name2init = self.optimizer.map_name_and_initializer(onnx_model)

    def _remap_input_and_node(self):
        """Rebuild the mapping from an input name to a (node, input index)
        tuple."""
        self.input2node = self.optimizer.map_input_and_node(self.onnx_model)

    def _remap_output_and_node(self):
        """Rebuild the mapping from a node's output name to the node
        itself."""
        self.output2node = self.optimizer.map_output_and_node(self.onnx_model)

    def parse_qparams(self, node: onnx.NodeProto):
        """Parse the quantization parameters referenced by a fakequant
        node."""
        tensor_name, scale, zero_point = node.input[:3]

        scale, zero_point = self.name2data[scale], self.name2data[zero_point]
        if len(node.input) > 3:
            qmin, qmax = node.input[-2:]
            qmin, qmax = self.name2data[qmin], self.name2data[qmax]
        elif len(node.attribute) > 0:
            qparams = _parse_attrs(node.attribute)
            qmin = qparams['quant_min']
            qmax = qparams['quant_max']
        else:
            print_log(f'qmin and qmax are not found for <{node.name}>!')
            qmax = qmin = None
        return tensor_name, scale, zero_point, qmin, qmax
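For illustration, the per-tensor input layout this parser expects (all tensor and node names below are hypothetical):

from onnx import helper

# Hypothetical fakequant node: input[0] is the quantized tensor, followed by
# the initializer names for scale, zero_point, quant_min and quant_max.
# parse_qparams resolves the last four via self.name2data and returns
# ('conv1_out', scale, zero_point, qmin, qmax).
node = helper.make_node(
    'FixedPerTensorAffine',
    inputs=['conv1_out', 'conv1_out_scale', 'conv1_out_zero_point',
            'conv1_out_qmin', 'conv1_out_qmax'],
    outputs=['conv1_out_fq'],
    name='conv1_out_fakequant')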

    def collect_symbolic_nodes(self, onnx_model: onnx.ModelProto):
        """Collect all the fakequant nodes from an onnx model."""
        symbolic_nodes = list()
        for node in onnx_model.graph.node:
            if node.op_type in ALL_FAKEQUANTIZER:
                symbolic_nodes.append(node)
        return symbolic_nodes

    def _get_constant_inputs(self, node: onnx.NodeProto):
        """Get the constant input nodes of the given node."""
        constant_nodes = list()
        output2node = self.output2node
        for inp in node.input:
            if inp in output2node and output2node[inp].op_type == 'Constant':
                cnode = output2node[inp]
                constant_nodes.append(cnode)
        return constant_nodes

    def _collect_symbolic_constant_inputs(self, symbolic_nodes: List):
        """Collect the constant nodes that are inputs of the symbolic nodes,
        without duplicates."""

        collected_constant_names = set()
        collected_constants = list()
        for node in symbolic_nodes:
            constant_inputs = self._get_constant_inputs(node)
            for constant in constant_inputs:
                if constant.name in collected_constant_names:
                    continue
                collected_constants.append(constant)
                collected_constant_names.add(constant.name)
        return collected_constants

    def _remove_symbolic_related_from_onnx(self, symbolic_nodes: List,
                                           symbolic_constant_inputs: List):
        """Remove the out-of-date fakequant nodes and their constant input
        nodes."""
        for node in symbolic_nodes:
            self.onnx_model.graph.node.remove(node)

        # Remove symbolic-related constant nodes. A constant node can only be
        # removed if it is used exclusively by the symbolic nodes.

        def _is_standalone_constant_node(constant):
            for node in self.onnx_model.graph.node:
                for input_name in node.input:
                    # A constant node always has exactly one output.
                    if input_name == constant.output[0]:
                        return False
            return True

        for constant in symbolic_constant_inputs:
            if _is_standalone_constant_node(constant):
                self.onnx_model.graph.node.remove(constant)

    def export(self):
        """Export an end-to-end onnx model."""
        # todo: should this be an abstract method?
        raise NotImplementedError
152 changes: 152 additions & 0 deletions mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py
@@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import numpy as np
import onnx
from google.protobuf.internal.containers import RepeatedScalarFieldContainer
from onnx import helper, numpy_helper

from .base_quantize_exporter import BaseQuantizeExportor


class OpenVinoQuantizeExportor(BaseQuantizeExportor):

    def __init__(self, onnx_model, export_path) -> None:
        super().__init__(onnx_model, export_path)

    def _build_backend_node_from_symbolic(self, node: onnx.NodeProto,
                                          tensor_name: str, qmin: np.ndarray,
                                          qmax: np.ndarray):
        """Build a new onnx node that can be deployed to the specific backend.
        These nodes will be used to replace the symbolic nodes in the original
        onnx model.
        """
        qmax = int(qmax)
        qmin = int(qmin)
        levels = qmax - qmin + 1
        # adjust weight levels
        # if levels == 128:
        #     levels = 256
        #     qmax = qmax * 2 + 1
        #     qmin = qmin * 2
        output_name = node.output[0]
        # Create a node (FakeQuantize)
        keys = ['input_min', 'input_max', 'output_min', 'output_max']
        input_names = [f'{tensor_name}_{key}' for key in keys]
        backend_node = helper.make_node(
            'FakeQuantize',  # node name
            [tensor_name, *input_names],  # inputs
            [output_name],  # outputs
            levels=levels,  # Attributes
            domain='org.openvinotoolkit',
            name=node.name)
        return backend_node

    def _build_backend_initializer(self,
                                   names: RepeatedScalarFieldContainer[str],
                                   scale: np.ndarray, zero_point: np.ndarray,
                                   qmin: np.ndarray, qmax: np.ndarray,
                                   shape: List[int]):
        """Build onnx initializers that can be deployed to the specific
        backend."""

        scale = np.abs(np.asarray(scale, dtype=np.float64).reshape(-1))
        zero_point = np.clip(
            np.asarray(np.round(zero_point), dtype=np.int32).reshape(-1),
            a_min=qmin,
            a_max=qmax)

        qrange = float(qmax - qmin)
        input_range = scale * qrange
        input_high = (qmax - zero_point).astype(
            np.float64) * input_range / qrange
        input_low = input_high - input_range
        input_low_size = input_low.size

        if input_low_size != 1:
            input_low = input_low.reshape(*shape)
            input_high = input_high.reshape(*shape)

        input_low = input_low.astype(np.float32)
        input_high = input_high.astype(np.float32)

        initializers = list()
        for init_name, value_tensor in zip(
                names, [input_low, input_high, input_low, input_high]):
            init = numpy_helper.from_array(value_tensor)
            init.name = init_name
            initializers.append(init)
        return initializers
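As a quick sanity check on the range arithmetic above (the qparams are made up), the formulas reduce to input_high = (qmax - zero_point) * scale and input_low = (qmin - zero_point) * scale:

# Made-up symmetric int8 qparams.
scale, zero_point, qmin, qmax = 0.1, 0, -128, 127

qrange = float(qmax - qmin)                               # 255.0
input_range = scale * qrange                              # 25.5
input_high = (qmax - zero_point) * input_range / qrange   # 12.7
input_low = input_high - input_range                      # -12.8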

    def build_backend_nodes_and_initializers(self, symbolic_nodes: List):
        """Build new onnx nodes and initializers that can be deployed to the
        specific backend."""
        backend_nodes = list()
        backend_initializers = list()
        for node in symbolic_nodes:
            tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(
                node)
            new_node = self._build_backend_node_from_symbolic(
                node, tensor_name, qmin, qmax)
            backend_nodes.append(new_node)

            try:
                # If the successor node (such as a conv node) has a weight,
                # we need to get the length of the weight's shape and ensure
                # that it matches the shape length of the new node's inputs
                # (such as input_low and input_high).
                next_node = self.input2node[node.output[0]][0][0]
                # node that stores the weights
                fake_node = self.output2node[next_node.input[1]]
                tensor = self.name2data[fake_node.input[0]]
                shape_length = len(tensor.shape)
                new_shape = [-1] + [1] * (shape_length - 1)
            except Exception:
                new_shape = [-1]

            # The first element of new_node.input is the tensor name.
            new_init_names = new_node.input[1:]
            new_initializers = self._build_backend_initializer(
                new_init_names, scale, zero_point, qmin, qmax, new_shape)
            backend_initializers.extend(new_initializers)
        return backend_nodes, backend_initializers

    def _insert_initializers_to_onnx(self, initializers: List):
        """Insert onnx initializers into the onnx graph."""
        inserted_init_names = set()
        for init in initializers:
            if init.name in inserted_init_names:
                continue

            self.onnx_model.graph.initializer.append(init)
            inserted_init_names.add(init.name)

    def _replace_symbolic_related(self):
        """Replace the symbolic-related nodes and initializers in the
        original onnx model with new nodes and initializers that can be
        deployed to the specific backend."""

        symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)

        collect_func = self._collect_symbolic_constant_inputs
        # Usually different activation fakequants share the same constant
        # input, and different weight fakequants share the same constant
        # input.
        symbolic_constant_inputs = collect_func(symbolic_nodes)

        build_func = self.build_backend_nodes_and_initializers
        new_nodes, new_initializers = build_func(symbolic_nodes)

        self._insert_initializers_to_onnx(new_initializers)

        self._remove_symbolic_related_from_onnx(symbolic_nodes,
                                                symbolic_constant_inputs)

        self.onnx_model.graph.node.extend(new_nodes)
        self.optimizer.optimize(self.onnx_model)

    def export(self):
        """Export an end-to-end onnx model."""
        self._replace_symbolic_related()
        onnx.save(self.onnx_model, self.export_path)
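A minimal usage sketch (both file paths are placeholders):

# Rewrite the symbolic fakequant nodes of a fake-quantized ONNX model into
# OpenVINO 'FakeQuantize' nodes and save the result. The first argument may
# be a path or an in-memory onnx.ModelProto.
exporter = OpenVinoQuantizeExportor('model_fakequant.onnx',
                                    'model_openvino.onnx')
exporter.export()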