[Feature] Add onnx exporters #475

Merged
merged 7 commits into from Mar 15, 2023
Changes from 4 commits
32 changes: 32 additions & 0 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -212,6 +212,38 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]):

        data = self.data_preprocessor(data, False)
        return self._run_forward(data, mode='predict')

    def post_process_for_mmdeploy(self,
                                  dummy_input: Tuple = (1, 3, 224, 224)):
        """Prepare the model for deployment to a backend with mmdeploy.

        The returned model is consumed by mmdeploy, and the preparation
        usually includes the following steps:

        1. prepare the float model rewritten by mmdeploy.
        2. load a checkpoint consisting of float weights and quantized
           parameters from mmrazor.
        3. post-process the weight fakequant nodes so that the exported
           .onnx meets the backend's requirements.
        """

        quantized_state_dict = self.qmodels['predict'].state_dict()
        fp32_model = self.architecture
        self.quantizer.convert_batchnorm2d(fp32_model)

        # concrete_args = {'mode': 'predict'}
        observed_model = self.quantizer.prepare(fp32_model)

        if dummy_input is not None:
            observed_model(torch.randn(dummy_input))

        observed_model.load_state_dict(quantized_state_dict)

        self.quantizer.post_process_for_deploy(
            observed_model, keep_fake_quant=True)

        # self.qmodels['predict'] = observed_model

        return observed_model


@MODEL_WRAPPERS.register_module()
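Note (editorial sketch, not part of the diff): once `post_process_for_mmdeploy` returns the observed model with fake-quant nodes kept, exporting it to onnx is what feeds the exporters added below. A minimal illustration, assuming `quant_algorithm` is an already-built quantization algorithm instance exposing the new method:

import torch

# `quant_algorithm` is an assumed, already-built algorithm instance.
observed_model = quant_algorithm.post_process_for_mmdeploy(
    dummy_input=(1, 3, 224, 224))

# keep_fake_quant=True keeps the fakequant nodes, so the exported onnx still
# carries the quantization parameters that the exporters below rewrite or strip.
torch.onnx.export(
    observed_model,
    torch.randn(1, 3, 224, 224),
    'quantized.onnx',  # placeholder output path
    opset_version=13)

In practice mmdeploy consumes the returned model; the snippet only shows where it fits.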
7 changes: 7 additions & 0 deletions mmrazor/models/quantizers/exporters/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .openvino_quantize_exporter import OpenVinoQuantizeExportor
from .tensorrt_quantize_exporter import (TensorRTQTableExporter,
                                         TensorRTExplicitExporter)

__all__ = ['OpenVinoQuantizeExportor', 'TensorRTQTableExporter',
           'TensorRTExplicitExporter']
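Note (editorial sketch, not part of the diff): the exporters share the base constructor signature `(onnx_model, export_path)` defined in base_quantize_exporter.py below. Assuming the OpenVINO exporter keeps that signature, usage might look like:

from mmrazor.models.quantizers.exporters import OpenVinoQuantizeExportor

# Paths are placeholders: the input is an onnx file that still contains
# fakequant nodes, the output is the backend-ready onnx.
exporter = OpenVinoQuantizeExportor('quantized.onnx', 'quantized_deploy.onnx')
exporter.export()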
328 changes: 328 additions & 0 deletions mmrazor/models/quantizers/exporters/base_quantize_exporter.py
@@ -0,0 +1,328 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import onnx
from mmengine import print_log
from onnx import numpy_helper

from .optim_utils import ONNXOptimUtils

SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose']

PERCHANNEL_FAKEQUANTIZER = [
    'FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine'
]
PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', 'FixedPerTensorAffine']

ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER


def _parse_attrs(node_attrs):
    attrs = {}
    for attr in node_attrs:
        if attr.type == onnx.AttributeProto.AttributeType.INTS:
            attrs[attr.name] = tuple(attr.ints)
        elif attr.type == onnx.AttributeProto.AttributeType.INT:
            attrs[attr.name] = attr.i
        elif attr.type == onnx.AttributeProto.AttributeType.FLOATS:
            attrs[attr.name] = tuple(attr.floats)
        elif attr.type == onnx.AttributeProto.AttributeType.FLOAT:
            attrs[attr.name] = attr.f
        elif attr.type == onnx.AttributeProto.AttributeType.TENSOR:
            attrs[attr.name] = numpy_helper.to_array(attr.t)
        elif attr.type == onnx.AttributeProto.AttributeType.STRING:
            attrs[attr.name] = str(attr.s)
        elif attr.type == onnx.AttributeProto.AttributeType.STRINGS:
            attrs[attr.name] = tuple([str(x) for x in attr.strings])
        else:
            raise Exception('ATTR Type [{}] Not Supported!'.format(attr.type))
    return attrs
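Note (editorial sketch, not part of the diff): `_parse_attrs` turns a node's attribute protos into a plain dict, e.g. for a node built with the standard `onnx.helper` API:

from onnx import helper

conv = helper.make_node(
    'Conv', inputs=['x', 'w'], outputs=['y'],
    group=1, kernel_shape=[3, 3], strides=[1, 1])
print(_parse_attrs(conv.attribute))
# {'group': 1, 'kernel_shape': (3, 3), 'strides': (1, 1)}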


class BaseQuantizeExportor():

    optimizer = ONNXOptimUtils

    def __init__(self, onnx_model, export_path) -> None:

        if isinstance(onnx_model, str):
            self.onnx_model = onnx.load(onnx_model)
        elif isinstance(onnx_model, onnx.ModelProto):
            self.onnx_model = onnx_model
        else:
            raise TypeError

        self.export_path = export_path
        self._init_mappings_from_onnx(self.onnx_model)

        self.optimizer.remove_fake_pad_op(self.onnx_model, self.name2data,
                                          self.input2node, self.output2node)

        self._remap_input_and_node()
        self._remap_output_and_node()

    @property
    def graph(self):
        """The onnx model's graph."""
        return self.onnx_model.graph

    def _init_mappings_from_onnx(self, onnx_model):
        """Build the necessary mappings from an onnx model."""

        self.input2node = self.optimizer.map_input_and_node(onnx_model)
        self.output2node = self.optimizer.map_output_and_node(onnx_model)
        self.name2data = self.optimizer.map_name_and_data(onnx_model)

        # todo: maybe useless
        # self.name2init = self.optimizer.map_name_and_initializer(onnx_model)

    def _remap_input_and_node(self):
        """Rebuild the mapping from an input name to the (node, input index)
        pairs that consume it."""
        self.input2node = self.optimizer.map_input_and_node(self.onnx_model)

    def _remap_output_and_node(self):
        """Rebuild the mapping from a node's output name to the node itself."""
        self.output2node = self.optimizer.map_output_and_node(self.onnx_model)
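Note (editorial sketch, not part of the diff): the mapping consumed throughout this file has the shape `{input_name: [[node, input_index], ...]}` (see the `n[0]`/`n[1]` accesses further down). A hand-rolled equivalent of what `map_input_and_node` is expected to return, for reference only and not the actual ONNXOptimUtils code, would be:

def build_input2node(model):
    input2node = {}
    for node in model.graph.node:
        for idx, inp in enumerate(node.input):
            # Every consumer of a tensor is recorded as (node, input index).
            input2node.setdefault(inp, []).append([node, idx])
    return input2node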

    def parse_qparams(self, node: onnx.NodeProto):
        """Parse the quantization parameters (scale, zero point, qmin, qmax)
        from a fakequant node."""
        tensor_name, scale, zero_point = node.input[:3]

        scale, zero_point = self.name2data[scale], self.name2data[zero_point]
        if len(node.input) > 3:
            qmin, qmax = node.input[-2:]
            qmin, qmax = self.name2data[qmin], self.name2data[qmax]
        elif len(node.attribute) > 0:
            qparams = _parse_attrs(node.attribute)
            qmin = qparams['quant_min']
            qmax = qparams['quant_max']
        else:
            print_log(f'qmin and qmax are not found for <{node.name}>!')
            qmax = qmin = None
        return tensor_name, scale, zero_point, qmin, qmax
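Note (editorial, background only): the values returned by `parse_qparams` are the usual affine fake-quant parameters, i.e. conceptually:

import numpy as np

def fake_quantize(x, scale, zero_point, qmin, qmax):
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)  # quantize
    return (q - zero_point) * scale  # dequantize back to float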

    def collect_symbolic_nodes(self, onnx_model: onnx.ModelProto):
        """Collect all the fakequant nodes from an onnx model."""
        symbolic_nodes = list()
        for node in onnx_model.graph.node:
            if node.op_type in ALL_FAKEQUANTIZER:
                symbolic_nodes.append(node)
        return symbolic_nodes

    def _get_constant_inputs(self, node: onnx.NodeProto):
        """Get the constant input nodes of the given node."""
        constant_nodes = list()
        output2node = self.output2node
        for inp in node.input:
            if inp in output2node and output2node[inp].op_type == 'Constant':
                cnode = output2node[inp]

                constant_nodes.append(cnode)
        return constant_nodes

    def _collect_symbolic_constant_inputs(self, symbolic_nodes: List):
        """Collect the constant nodes that are inputs of the symbolic
        (fakequant) nodes."""

        collected_constant_names = set()
        collected_constants = list()
        for node in symbolic_nodes:
            constant_inputs = self._get_constant_inputs(node)
            for constant in constant_inputs:
                if constant.name in collected_constant_names:
                    continue
                # Accumulate into a separate list so the list being iterated
                # is not modified in place.
                collected_constants.append(constant)
                collected_constant_names.add(constant.name)
        return collected_constants

    def _remove_symbolic_related_from_onnx(self, symbolic_nodes: List,
                                           symbolic_constant_inputs: List):
        """Remove the out-of-date fakequant nodes and their constant input
        nodes."""
        for node in symbolic_nodes:
            self.onnx_model.graph.node.remove(node)

        # Remove symbolic-related constant nodes. A constant node can only be
        # removed if it is used exclusively by those symbolic nodes.

        def _is_standalone_constant_node(constant):
            for node in self.onnx_model.graph.node:
                for input_name in node.input:
                    # A constant node always has exactly one output.
                    if input_name == constant.output[0]:
                        return False
            return True

        for constant in symbolic_constant_inputs:
            if _is_standalone_constant_node(constant):
                self.onnx_model.graph.node.remove(constant)

    def export(self):
        """Export the end-to-end onnx model."""
        # todo: is it an abstract method?
        raise NotImplementedError
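Note (editorial sketch, not part of the diff): `export` is effectively abstract here; a hypothetical minimal subclass, for illustration only, could look like:

class DummyQuantizeExportor(BaseQuantizeExportor):
    """Illustrative subclass, not part of this PR."""

    def export(self):
        # Real exporters rewrite or strip the fakequant nodes first; this
        # sketch just writes the model to the configured path.
        onnx.save(self.onnx_model, self.export_path)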


class QTableQuantizeExportor(BaseQuantizeExportor):

    def __init__(self, onnx_model, export_path) -> None:
        super().__init__(onnx_model, export_path)

        self._qtables = dict()  # type: ignore

    @property
    def qtables(self):
        return self._qtables

    def register_qtables(self, value, key):
        assert value not in self._qtables
        self._qtables[value] = key

    def post_process_qtables(self):
        """Let the input of a `Flatten`/`Resize` node reuse the clip range of
        the closest registered tensor, found by walking through consecutive
        `Flatten`/`Resize` nodes."""

        def find_the_closest_tensor(node):
            if node.input[0] in self.qtables:
                return node.input[0]
            elif (node.op_type in ['Flatten', 'Resize']
                  and node.output[0] in self.input2node):

                next_node = self.input2node[node.output[0]][0][0]
                return find_the_closest_tensor(next_node)
            else:
                return None

        for node in self.graph.node:
            if node.op_type in ['Flatten', 'Resize']:
                tensor_name = find_the_closest_tensor(node)
                if tensor_name:
                    self.qtables[node.input[0]] = self.qtables[tensor_name]
                    print_log(
                        f'Pass <{tensor_name}> clip range to <{node.name}> '
                        f'input <{node.input[0]}>.')

    def _is_fakequant_for_weight(self, node):
        """Whether every consumer uses this fakequant output as the weight
        (input index 1) of a `Gemm`/`Conv`/`ConvTranspose` node."""

        if node.output[0] not in self.input2node:

            assert node.output[0] in [out.name for out in self.graph.output], \
                f'{node.name} not in graph.'

            self.input2node[node.output[0]] = []
        next_nodes = self.input2node[node.output[0]]

        flag = True
        for n in next_nodes:
            if n[1] == 1 and n[0].op_type in SUPPORT_QWEIGHT_NODE:
                continue
            else:
                flag = False
                break

        return flag

    def _is_fakequant_for_bias(self, node):
        """Whether every consumer uses this fakequant output as the bias
        (input index 2) of a `Gemm`/`Conv`/`ConvTranspose` node."""

        if node.output[0] not in self.input2node:

            assert node.output[0] in [out.name for out in self.graph.output], \
                f'{node.name} not in graph.'

            self.input2node[node.output[0]] = []
        next_nodes = self.input2node[node.output[0]]

        flag = True
        for n in next_nodes:
            if n[1] == 2 and n[0].op_type in SUPPORT_QWEIGHT_NODE:
                continue
            else:
                flag = False
                break

        return flag

    def _is_fakequant_for_activation(self, node):

        return (not self._is_fakequant_for_weight(node)
                and not self._is_fakequant_for_bias(node))

    def deal_with_weight_fakequant(self, node):
        # Bypass the fakequant node: rewire its first consumer to read the
        # original weight tensor directly.
        next_nodes = self.input2node[node.output[0]]
        next_node, idx = next_nodes[0]
        next_node.input[idx] = node.input[0]

    def deal_with_activation_fakequant(self, node):
        # Bypass the fakequant node for every consumer of the activation.
        next_nodes = self.input2node[node.output[0]]
        for next_node, idx in next_nodes:
            next_node.input[idx] = node.input[0]

    def deal_with_per_channel_node(self, node):
        # Fake quantize for weights.
        # Assume per-channel quantization is only applied to weights.
        if not self._is_fakequant_for_weight(node):
            raise RuntimeError('Only support per-channel quantize for weight')
        self.deal_with_weight_fakequant(node)

    def deal_with_per_tensor_node(self, node):

        if self._is_fakequant_for_weight(node):
            self.deal_with_per_tensor_weight(node)
        elif self._is_fakequant_for_bias(node):
            self.deal_with_per_tensor_bias(node)
        elif self._is_fakequant_for_activation(node):
            self.deal_with_per_tensor_activation(node)
        else:
            raise NotImplementedError

    def deal_with_per_tensor_weight(self, node):
        # Fake quantize for weights.
        self.deal_with_weight_fakequant(node)

    def deal_with_per_tensor_bias(self, node):
        # Fake quantize for bias.
        raise RuntimeError(f"{self.backend} don't support per-tensor quantize "
                           f'for bias')

    def deal_with_per_tensor_activation(self, node):

        # Fake quantize for activations.

        self.deal_with_activation_fakequant(node)
        tensor_name, _, _, _, _ = self.parse_qparams(node)
        for out in self.graph.output:
            if out.name == node.output[0]:
                out.name = tensor_name

    def _remove_symbolic_and_collect_params(self):
        symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)

        collect_func = self._collect_symbolic_constant_inputs
        symbolic_constant_inputs = collect_func(symbolic_nodes)

        for node in symbolic_nodes:
            if node.op_type in PERCHANNEL_FAKEQUANTIZER:
                self.deal_with_per_channel_node(node)
            else:
                self.deal_with_per_tensor_node(node)

        self._remove_symbolic_related_from_onnx(symbolic_nodes,
                                                symbolic_constant_inputs)

        self.optimizer.optimize(self.onnx_model)

        self._remap_input_and_node()
        self._remap_output_and_node()

        self.post_process_qtables()

    def export_qtables(self):
        """Export the quantization table (left empty here; overridden by
        backend-specific exporters)."""
        pass

    def export(self):
        self._remove_symbolic_and_collect_params()

        onnx.save(self.onnx_model, self.export_path)

        self.export_qtables()