diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py
index d3b0be089..ff99592d6 100644
--- a/mmrazor/models/algorithms/quantization/mm_architecture.py
+++ b/mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -13,12 +13,14 @@ try:
     from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver,
-                                       PerChannelMinMaxObserver)
+                                       PerChannelMinMaxObserver,
+                                       disable_observer)
 except ImportError:
     from mmrazor.utils import get_placeholder
     FakeQuantizeBase = get_placeholder('torch>=1.13')
     MinMaxObserver = get_placeholder('torch>=1.13')
     PerChannelMinMaxObserver = get_placeholder('torch>=1.13')
+    disable_observer = get_placeholder('torch>=1.13')

 LossResults = Dict[str, torch.Tensor]
 TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
@@ -213,6 +215,34 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]):
         data = self.data_preprocessor(data, False)
         return self._run_forward(data, mode='predict')

+    def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)):
+        """Prepare the model to deploy to a backend with mmdeploy. It is
+        called by mmdeploy and usually includes the following steps:
+
+        1. Prepare the float model rewritten by mmdeploy.
+        2. Load the checkpoint, which contains both the float weights and
+           the quantization parameters learned in mmrazor.
+        3. Post-process the weight fakequant so that the exported .onnx
+           meets the backend's requirements.
+        """
+
+        quantized_state_dict = self.qmodels['tensor'].state_dict()
+        fp32_model = self.architecture
+        self.quantizer.convert_batchnorm2d(fp32_model)
+        observed_model = self.quantizer.prepare(fp32_model, {'mode': 'tensor'})
+
+        if dummy_input is not None:
+            observed_model(torch.randn(dummy_input))
+
+        observed_model.load_state_dict(quantized_state_dict)
+
+        self.quantizer.post_process_for_deploy(
+            observed_model, keep_w_fake_quant=True)
+
+        observed_model.apply(disable_observer)
+
+        return observed_model
+

 @MODEL_WRAPPERS.register_module()
 class MMArchitectureQuantDDP(MMDistributedDataParallel):
diff --git a/mmrazor/models/quantizers/exporters/__init__.py b/mmrazor/models/quantizers/exporters/__init__.py
new file mode 100644
index 000000000..b8153289d
--- /dev/null
+++ b/mmrazor/models/quantizers/exporters/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .openvino_quantize_exporter import OpenVinoQuantizeExportor
+from .tensorrt_quantize_exporter import TensorRTExplicitExporter
+
+__all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter']
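For context, a minimal sketch of the call flow that `post_process_for_mmdeploy` enables. The `algorithm` handle is assumed to be an `MMArchitectureQuant` whose QAT checkpoint is already loaded; names and paths are illustrative, not part of this diff:

    # `algorithm` is a trained MMArchitectureQuant with its checkpoint loaded.
    observed_model = algorithm.post_process_for_mmdeploy(
        dummy_input=(1, 3, 224, 224))
    # The returned graph has BN folded, quantized params restored, weight
    # fakequant kept and observers disabled, ready for onnx export.
    torch.onnx.export(observed_model, torch.randn(1, 3, 224, 224),
                      'symbolic_end2end.onnx', opset_version=11)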
diff --git a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py
new file mode 100644
index 000000000..7e1e1f375
--- /dev/null
+++ b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import onnx
+from mmengine import print_log
+from onnx import numpy_helper
+
+from .optim_utils import ONNXOptimUtils
+
+SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose']
+
+PERCHANNEL_FAKEQUANTIZER = [
+    'FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine'
+]
+PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', 'FixedPerTensorAffine']
+
+ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER
+
+
+def _parse_attrs(node_attrs):
+    attrs = {}
+    for attr in node_attrs:
+        if attr.type == onnx.AttributeProto.AttributeType.INTS:
+            attrs[attr.name] = tuple(attr.ints)
+        elif attr.type == onnx.AttributeProto.AttributeType.INT:
+            attrs[attr.name] = attr.i
+        elif attr.type == onnx.AttributeProto.AttributeType.FLOATS:
+            attrs[attr.name] = tuple(attr.floats)
+        elif attr.type == onnx.AttributeProto.AttributeType.FLOAT:
+            attrs[attr.name] = attr.f
+        elif attr.type == onnx.AttributeProto.AttributeType.TENSOR:
+            attrs[attr.name] = numpy_helper.to_array(attr.t)
+        elif attr.type == onnx.AttributeProto.AttributeType.STRING:
+            attrs[attr.name] = str(attr.s)
+        elif attr.type == onnx.AttributeProto.AttributeType.STRINGS:
+            attrs[attr.name] = tuple([str(x) for x in attr.strings])
+        else:
+            raise Exception('ATTR Type [{}] Not Supported!'.format(attr.type))
+    return attrs
+
+
+class BaseQuantizeExportor():
+
+    optimizer = ONNXOptimUtils
+
+    def __init__(self, onnx_model, export_path) -> None:
+
+        if isinstance(onnx_model, str):
+            self.onnx_model = onnx.load(onnx_model)
+        elif isinstance(onnx_model, onnx.ModelProto):
+            self.onnx_model = onnx_model
+        else:
+            raise TypeError('`onnx_model` should be an onnx file path or an '
+                            f'onnx.ModelProto, but got {type(onnx_model)}')
+
+        self.export_path = export_path
+        self._init_mappings_from_onnx(self.onnx_model)
+
+        self.optimizer.remove_fake_pad_op(self.onnx_model, self.name2data,
+                                          self.input2node, self.output2node)
+
+        self._remap_input_and_node()
+        self._remap_output_and_node()
+
+    @property
+    def graph(self):
+        """The onnx model's graph."""
+        return self.onnx_model.graph
+
+    def _init_mappings_from_onnx(self, onnx_model):
+        """Build the necessary mappings of an onnx model."""
+
+        self.input2node = self.optimizer.map_input_and_node(onnx_model)
+        self.output2node = self.optimizer.map_output_and_node(onnx_model)
+        self.name2data = self.optimizer.map_name_and_data(onnx_model)
+
+        # todo: maybe useless
+        # self.name2init = self.optimizer.map_name_and_initializer(onnx_model)
+
+    def _remap_input_and_node(self):
+        """Rebuild the mapping from input name to a (node, input index)
+        tuple."""
+        self.input2node = self.optimizer.map_input_and_node(self.onnx_model)
+
+    def _remap_output_and_node(self):
+        """Rebuild the mapping from a node's output name to this node."""
+        self.output2node = self.optimizer.map_output_and_node(self.onnx_model)
+
+    def parse_qparams(self, node: onnx.NodeProto):
+        """Parse the quantization parameters from a fakequant node."""
+        tensor_name, scale, zero_point = node.input[:3]
+
+        scale, zero_point = self.name2data[scale], self.name2data[zero_point]
+        if len(node.input) > 3:
+            qmin, qmax = node.input[-2:]
+            qmin, qmax = self.name2data[qmin], self.name2data[qmax]
+        elif len(node.attribute) > 0:
+            qparams = _parse_attrs(node.attribute)
+            qmin = qparams['quant_min']
+            qmax = qparams['quant_max']
+        else:
+            print_log(f'qmin and qmax are not found for <{node.name}>!')
+            qmax = qmin = None
+        return tensor_name, scale, zero_point, qmin, qmax
+
+    def collect_symbolic_nodes(self, onnx_model: onnx.ModelProto):
+        """Collect all the fakequant nodes from an onnx model."""
+        symbolic_nodes = list()
+        for node in onnx_model.graph.node:
+            if node.op_type in ALL_FAKEQUANTIZER:
+                symbolic_nodes.append(node)
+        return symbolic_nodes
+
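A sketch of what `parse_qparams` yields, assuming `symbolic_model.onnx` was produced by `torch.onnx.export` with the custom symbolics registered later in this diff (all values illustrative):

    exporter = BaseQuantizeExportor('symbolic_model.onnx', 'out.onnx')
    for node in exporter.collect_symbolic_nodes(exporter.onnx_model):
        name, scale, zero_point, qmin, qmax = exporter.parse_qparams(node)
        # e.g. for a FixedPerTensorAffine node: ('input', 0.0235, 0, 0, 255)
        print(node.op_type, name, scale, zero_point, qmin, qmax)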
+    def _get_constant_inputs(self, node: onnx.NodeProto):
+        """Get the constant input nodes of the current node."""
+        constant_nodes = list()
+        output2node = self.output2node
+        for inp in node.input:
+            if inp in output2node and output2node[inp].op_type == 'Constant':
+                cnode = output2node[inp]
+
+                constant_nodes.append(cnode)
+        return constant_nodes
+
+    def _collect_symbolic_constant_inputs(self, symbolic_nodes: List):
+        """Collect the constant nodes that are inputs of the symbolic
+        nodes."""
+
+        collected_constant_names = set()
+        collected_constants = list()
+        for node in symbolic_nodes:
+            constant_inputs = self._get_constant_inputs(node)
+            for constant in constant_inputs:
+                if constant.name in collected_constant_names:
+                    continue
+                collected_constants.append(constant)
+                collected_constant_names.add(constant.name)
+        return collected_constants
+
+    def _remove_symbolic_related_from_onnx(self, symbolic_nodes: List,
+                                           symbolic_constant_inputs: List):
+        """Remove the out-of-date fakequant nodes and their constant input
+        nodes."""
+        for node in symbolic_nodes:
+            self.onnx_model.graph.node.remove(node)
+
+        # Remove symbolic-related constant nodes. A constant node can only
+        # be removed if it is used exclusively by the symbolic nodes.
+
+        def _is_standalone_constant_node(constant):
+            for node in self.onnx_model.graph.node:
+                for input_name in node.input:
+                    # A constant node always has one output.
+                    if input_name == constant.output[0]:
+                        return False
+            return True
+
+        for constant in symbolic_constant_inputs:
+            if _is_standalone_constant_node(constant):
+                self.onnx_model.graph.node.remove(constant)
+
+    def export(self):
+        """Export an end-to-end onnx model."""
+        # todo: is it an abstract method?
+        raise NotImplementedError
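The two backend exporters below follow the same recipe; a hypothetical third backend would subclass this base like the following sketch (class name and rewrite logic are illustrative):

    class MyBackendExporter(BaseQuantizeExportor):

        def export(self):
            symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)
            for node in symbolic_nodes:
                # Rewrite each fakequant node into the backend's own op here,
                # using parse_qparams(node) for scale/zero_point/qmin/qmax.
                ...
            self.optimizer.optimize(self.onnx_model)
            onnx.save(self.onnx_model, self.export_path)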
+ """ + qmax = int(qmax) + qmin = int(qmin) + levels = qmax - qmin + 1 + # adjust weight levels + # if levels == 128: + # levels = 256 + # qmax = qmax * 2 + 1 + # qmin = qmin * 2 + output_name = node.output[0] + # Create a node (FakeQuantize) + keys = ['input_min', 'input_max', 'output_min', 'output_max'] + input_names = [f'{tensor_name}_{key}' for key in keys] + backend_node = helper.make_node( + 'FakeQuantize', # node name + [tensor_name, *input_names], # inputs + [output_name], # outputs + levels=levels, # Attributes + domain='org.openvinotoolkit', + name=node.name) + return backend_node + + def _build_backend_initializer(self, + names: RepeatedScalarFieldContainer[str], + scale: np.ndarray, zero_point: np.ndarray, + qmin: np.ndarray, qmax: np.ndarray, + shape: List[int]): + """Build onnx initializers which can be deployed to specific + backend.""" + + scale = np.abs(np.asarray(scale, dtype=np.float64).reshape(-1)) + zero_point = np.clip( + np.asarray(np.round(zero_point), dtype=np.int32).reshape(-1), + a_min=qmin, + a_max=qmax) + + qrange = float(qmax - qmin) + input_range = scale * qrange + input_high = (qmax - zero_point).astype( + np.float64) * input_range / qrange + input_low = input_high - input_range + input_low_size = input_low.size + + if input_low_size != 1: + input_low = input_low.reshape(*shape) + input_high = input_high.reshape(*shape) + + input_low = input_low.astype(np.float32) + input_high = input_high.astype(np.float32) + + initializers = list() + for init_name, value_tensor in zip( + names, [input_low, input_high, input_low, input_high]): + init = numpy_helper.from_array(value_tensor) + init.name = init_name + initializers.append(init) + return initializers + + def build_backend_nodes_and_initializers(self, symbolic_nodes: List): + """Build new onnx nodes and initializers which can be deployed to + specific backend.""" + backend_nodes = list() + backend_initializers = list() + for node in symbolic_nodes: + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams( + node) + new_node = self._build_backend_node_from_symbolic( + node, tensor_name, qmin, qmax) + backend_nodes.append(new_node) + + try: + # If the successor node (such as a conv node) has weight, + # we need get the length of the weight's shape. And ensure + # the length of the weight's shape and the new node's + # input shape (such as input_low and input_high) is the same. + next_node = self.input2node[node.output[0]][0][0] + # node for save weights + fake_node = self.output2node[next_node.input[1]] + tensor = self.name2data[fake_node.input[0]] + shape_length = len(tensor.shape) + new_shape = [-1] + [1] * (shape_length - 1) + except Exception: + new_shape = [-1] + + # The first element of new_node.input is the tensor name. 
+    def _insert_initializers_to_onnx(self, initializers: List):
+        """Insert onnx initializers into the onnx graph."""
+        inserted_init_names = set()
+        for init in initializers:
+            if init.name in inserted_init_names:
+                continue
+
+            self.onnx_model.graph.initializer.append(init)
+            inserted_init_names.add(init.name)
+
+    def _replace_symbolic_related(self):
+        """Replace the symbolic-related nodes and initializers in the
+        original onnx model with new nodes and initializers that can be
+        deployed to the specific backend."""
+
+        symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)
+
+        collect_func = self._collect_symbolic_constant_inputs
+        # Usually different activation fakequants share the same constant
+        # input, and different weight fakequants share the same constant
+        # input.
+        symbolic_constant_inputs = collect_func(symbolic_nodes)
+
+        build_func = self.build_backend_nodes_and_initializers
+        new_nodes, new_initializers = build_func(symbolic_nodes)
+
+        self._insert_initializers_to_onnx(new_initializers)
+
+        self._remove_symbolic_related_from_onnx(symbolic_nodes,
+                                                symbolic_constant_inputs)
+
+        self.onnx_model.graph.node.extend(new_nodes)
+        self.optimizer.optimize(self.onnx_model)
+
+    def export(self):
+        """Export an end-to-end onnx model."""
+        self._replace_symbolic_related()
+        onnx.save(self.onnx_model, self.export_path)
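A sketch of how this exporter is driven end to end (paths illustrative; `observed_model` comes from `post_process_for_mmdeploy` above):

    torch.onnx.export(observed_model, torch.randn(1, 3, 224, 224),
                      'symbolic_model.onnx', opset_version=11)
    exporter = OpenVinoQuantizeExportor('symbolic_model.onnx',
                                        'deploy_model.onnx')
    exporter.export()  # writes the FakeQuantize-based model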
diff --git a/mmrazor/models/quantizers/exporters/optim_utils.py b/mmrazor/models/quantizers/exporters/optim_utils.py
new file mode 100644
index 000000000..62b348d1c
--- /dev/null
+++ b/mmrazor/models/quantizers/exporters/optim_utils.py
@@ -0,0 +1,281 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Dict, List, Optional
+
+import onnx
+from mmengine import print_log
+from onnx import numpy_helper
+
+
+class ONNXOptimUtils():
+
+    @classmethod
+    def map_name_and_data(cls, onnx_model: onnx.ModelProto):
+        """Build the mapping from a tensor's name to its data."""
+        params = {}
+        for init in onnx_model.graph.initializer:
+            params[init.name] = numpy_helper.to_array(init)
+        for node in onnx_model.graph.node:
+            # If two zero_points are identical, onnx may fold one into an
+            # Identity node that references the other during optimization.
+            if node.op_type == 'Identity' and len(node.input) == 1 and \
+                    node.input[0] in params:
+                params[node.output[0]] = copy.deepcopy(params[node.input[0]])
+            if node.op_type == 'Constant':
+                for attr in node.attribute:
+                    if attr.name == 'value':
+                        params[node.output[0]] = numpy_helper.to_array(attr.t)
+        return params
+
+    @classmethod
+    def map_name_and_initializer(cls,
+                                 onnx_model: onnx.ModelProto,
+                                 allow_redundant=True):
+        """Build the mapping from an initializer's name to a (initializer,
+        index) tuple."""
+
+        initializers = dict()
+
+        for idx, init in enumerate(onnx_model.graph.initializer):
+            initializers[init.name] = (init, idx)
+
+        return initializers
+
+    @classmethod
+    def map_output_and_node(cls, onnx_model: onnx.ModelProto):
+        """Build the mapping from a node's output name to this node."""
+        output2node = dict()
+        for node in onnx_model.graph.node:
+            for output_name in node.output:
+                output2node[output_name] = node
+        return output2node
+
+    @classmethod
+    def map_input_and_node(cls, onnx_model: onnx.ModelProto):
+        """Build the mapping from an input name to a (node, input index)
+        tuple."""
+
+        input2node: Dict[str, List] = dict()
+        for node in onnx_model.graph.node:
+            for idx, input_name in enumerate(node.input):
+                if input_name not in input2node:
+                    input2node[input_name] = []
+                input2node[input_name].append([node, idx])
+        return input2node
+
+    @classmethod
+    def get_constant(cls, name, onnx_model):
+        """Get the value of the Constant node with the given output name."""
+        for node in onnx_model.graph.node:
+            if node.op_type == 'Constant':
+                if node.output[0] == name:
+                    return numpy_helper.to_array(node.attribute[0].t).tolist()
+
+    @classmethod
+    def get_initializer(cls, initializer_name, onnx_model):
+        """Get the initializer with the given name as a numpy array."""
+        initializers = cls.map_name_and_initializer(onnx_model)
+        return numpy_helper.to_array(initializers[initializer_name][0])
+
+    @classmethod
+    def get_tensor_producer(cls, output_name, output2node):
+        """Get the node that produces the given tensor."""
+        if output_name not in output2node:
+            return 'INPUT_TOKEN'
+        return output2node[output_name]
+
+    @classmethod
+    def get_tensor_consumer(cls, input_name, input2node):
+        """Get the nodes that consume the given tensor."""
+        if input_name not in input2node:
+            return ['OUTPUT_TOKEN']
+        return input2node[input_name]
+
+    @classmethod
+    def remove_node_from_onnx(cls, node: onnx.NodeProto,
+                              onnx_model: onnx.ModelProto):
+        """Remove a node from the node list."""
+        onnx_model.graph.node.remove(node)
+
+    @classmethod
+    def remove_initializer_from_onnx(cls, initializer: onnx.TensorProto,
+                                     onnx_model: onnx.ModelProto):
+        """Remove an initializer from the initializer list."""
+        onnx_model.graph.initializer.remove(initializer)
+
+    @classmethod
+    def remove_fake_pad_op(cls, onnx_model, name2data, inp2node, out2node):
+        """Remove Pad ops whose paddings are all zero."""
+        nodes_to_be_removed = []
+        for node in onnx_model.graph.node:
+            if node.op_type == 'Pad':
+                pads = name2data[node.input[1]]
+                if all([x == 0 for x in pads]):
+                    print_log(f'Remove pad op: <{node.name}>.')
+                    next_nodes = inp2node[node.output[0]]
+                    for next_node, idx in next_nodes:
+                        next_node.input[idx] = node.input[0]
+                    nodes_to_be_removed.append(node)
+
+        for node in nodes_to_be_removed:
+            onnx_model.graph.node.remove(node)
+
+    @classmethod
+    def insert_node_to_onnx(cls,
+                            node: onnx.NodeProto,
+                            onnx_model: onnx.ModelProto,
+                            idx: int = 0):
+        """Insert a node at the specified position."""
+        onnx_model.graph.node.insert(idx, node)
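A self-contained toy example of the mapping helpers above (graph and names illustrative):

    from onnx import TensorProto, helper

    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
    relu = helper.make_node('Relu', ['x'], ['h'], name='relu0')
    neg = helper.make_node('Neg', ['h'], ['y'], name='neg0')
    model = helper.make_model(helper.make_graph([relu, neg], 'toy', [x], [y]))

    input2node = ONNXOptimUtils.map_input_and_node(model)
    output2node = ONNXOptimUtils.map_output_and_node(model)
    assert output2node['h'].name == 'relu0'
    assert input2node['h'][0][0].name == 'neg0'  # [node, input index] pairs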
+    @classmethod
+    def find_standalone_nodes(cls,
+                              onnx_model: onnx.ModelProto,
+                              input2node: Optional[Dict] = None,
+                              output2node: Optional[Dict] = None):
+        """Find unused nodes."""
+
+        if input2node is None:
+            input2node = cls.map_input_and_node(onnx_model)
+        if output2node is None:
+            output2node = cls.map_output_and_node(onnx_model)
+
+        def _is_standalone_node(node, input2node, output2node):
+            for input_name in node.input:
+                if input_name in output2node:
+                    return False
+
+            for out_node in node.output:
+                if out_node in input2node:
+                    return False
+
+            return True
+
+        standalone_nodes = list()
+        for node in onnx_model.graph.node:
+            if _is_standalone_node(node, input2node, output2node):
+                standalone_nodes.append(node)
+        return standalone_nodes
+
+    @classmethod
+    def find_redundant_initializers(cls,
+                                    onnx_model: onnx.ModelProto,
+                                    input2node: Optional[Dict] = None):
+        """Find unused initializers."""
+        if input2node is None:
+            input2node = cls.map_input_and_node(onnx_model)
+
+        initializers = cls.map_name_and_initializer(onnx_model)
+        redundant_initializers = list()
+        redundant_set = set()
+        for name, init_and_idx in initializers.items():
+            if name not in input2node and name not in redundant_set:
+                # init_and_idx[0] is an onnx.onnx_ml_pb2.TensorProto
+                # init_and_idx[1] is an integer index
+                redundant_initializers.append(init_and_idx[0])
+                redundant_set.add(name)
+        return redundant_initializers
+
+    @classmethod
+    def topo_sort(cls,
+                  onnx_model: onnx.ModelProto,
+                  initializers: Optional[Dict] = None,
+                  inplace: bool = True):
+        """Topologically sort the nodes in a directed acyclic graph.
+
+        Note that the nodes in a directed acyclic graph may be out of order
+        after replacing the symbolic-related nodes with new nodes.
+
+        Args:
+            onnx_model (onnx.ModelProto): The onnx model to be sorted
+                topologically.
+            initializers (Dict, optional): The mapping from name to
+                initializers. Defaults to None.
+            inplace (bool): Can optionally do the operation in-place.
+                Defaults to True.
+        """
+
+        if inplace:
+            _onnx_model = onnx_model
+        else:
+            _onnx_model = copy.deepcopy(onnx_model)
+
+        if initializers is None:
+            initializers = cls.map_name_and_initializer(
+                _onnx_model, allow_redundant=True)
+
+        # A node may have multiple outputs. The first output name of a node
+        # named `/conv/Conv` is `/conv/Conv_output_0`.
+        output_name2node = {}
+        for node in _onnx_model.graph.node:
+            for output_name in node.output:
+                output_name2node[output_name] = node
+        for node in _onnx_model.graph.input:
+            output_name2node[node.name] = node
+
+        name2node = {node.name: node for node in _onnx_model.graph.node}
+
+        graph: Dict[str, List] = {
+            node.name: []
+            for node in _onnx_model.graph.node
+        }
+        for node in _onnx_model.graph.input:
+            graph[node.name] = []
+
+        indegree = {node.name: 0 for node in _onnx_model.graph.node}
+
+        # Build the graph.
+        for node in _onnx_model.graph.node:
+            for input_name in node.input:
+                if input_name not in initializers:
+                    indegree[node.name] += 1
+                    prev_node = output_name2node[input_name]
+                    graph[prev_node.name].append(node)
+
+        graph_input = [node.name for node in _onnx_model.graph.input]
+        root = graph_input.copy()
+        sorted_nodes = []
+
+        # Some nodes have only initializers as inputs; their indegree is
+        # zero, so they are also roots.
+        for node_name, in_degree in indegree.items():
+            if in_degree == 0:
+                root.append(node_name)
+
+        while root:
+            node_name = root.pop()
+            # There is no intersection between graph_input and
+            # _onnx_model.graph.node.
+            if node_name not in graph_input:
+                node = name2node[node_name]
+                sorted_nodes.append(node)
+            for next_node in graph[node_name]:
+                indegree[next_node.name] -= 1
+                if indegree[next_node.name] == 0:
+                    root.append(next_node.name)
+
+        num_nodes = len(_onnx_model.graph.node)
+        if len(sorted_nodes) != num_nodes:
+            raise RuntimeError('The graph is not a DAG.')
+
+        for _ in range(num_nodes):
+            _onnx_model.graph.node.pop()
+        for node in sorted_nodes:
+            _onnx_model.graph.node.append(node)
+
+        return _onnx_model
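`topo_sort` is essentially Kahn's algorithm; a pure-Python sketch of the same indegree bookkeeping on a toy graph:

    from collections import deque

    edges = {'a': ['b', 'c'], 'b': ['d'], 'c': ['d'], 'd': []}
    indegree = {n: 0 for n in edges}
    for succs in edges.values():
        for s in succs:
            indegree[s] += 1

    queue = deque(n for n, d in indegree.items() if d == 0)
    order = []
    while queue:
        n = queue.popleft()
        order.append(n)
        for s in edges[n]:
            indegree[s] -= 1
            if indegree[s] == 0:
                queue.append(s)
    assert order == ['a', 'b', 'c', 'd']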
+    @classmethod
+    def optimize(cls, onnx_model):
+        """Remove standalone nodes and redundant initializers, then
+        topologically sort the remaining nodes."""
+
+        input2node = cls.map_input_and_node(onnx_model)
+        output2node = cls.map_output_and_node(onnx_model)
+
+        standalone_nodes = cls.find_standalone_nodes(onnx_model, input2node,
+                                                     output2node)
+        for node in standalone_nodes:
+            cls.remove_node_from_onnx(node, onnx_model)
+            print_log(f'Remove node {node.name}')
+
+        redundant_inits = cls.find_redundant_initializers(
+            onnx_model, input2node)
+        for init in redundant_inits:
+            cls.remove_initializer_from_onnx(init, onnx_model)
+            print_log(f'Remove initializer {init.name}')
+
+        sorted_onnx_model = cls.topo_sort(onnx_model)
+
+        return sorted_onnx_model
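Context for the next file: TensorRT's explicit quantization consumes standard QuantizeLinear/DequantizeLinear (QDQ) pairs rather than fakequant ops. A numpy sketch of what such a pair computes for a symmetric int8 tensor (values illustrative):

    import numpy as np

    x, scale, zero_point = np.array([0.34, -1.27]), 0.01, 0
    q = np.clip(np.round(x / scale) + zero_point, -128, 127)  # QuantizeLinear
    dq = (q - zero_point) * scale                             # DequantizeLinear
    # dq == [0.34, -1.27]: the pair reproduces the fakequant rounding.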
diff --git a/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py
new file mode 100644
index 000000000..7d05847c1
--- /dev/null
+++ b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import onnx
+
+from .base_quantize_exporter import BaseQuantizeExportor
+
+
+class TensorRTExplicitExporter(BaseQuantizeExportor):
+
+    def __init__(self, onnx_model, export_path) -> None:
+        super().__init__(onnx_model, export_path)
+
+    def _build_backend_node_from_symbolic(self, node):
+        """Build a QuantizeLinear/DequantizeLinear pair to replace a symbolic
+        fakequant node."""
+        quantize_linear_node = onnx.helper.make_node(
+            'QuantizeLinear', node.input[:3], [node.name + '_quantized_out'],
+            node.name + '_quantized')
+        dequantize_linear_node = onnx.helper.make_node(
+            'DequantizeLinear',
+            [node.name + '_quantized_out'] + quantize_linear_node.input[1:3],
+            node.output, node.name + '_dequantized')
+        return [quantize_linear_node, dequantize_linear_node]
+
+    def build_backend_nodes(self, symbolic_nodes):
+        """Build new onnx nodes that can be deployed to TensorRT."""
+        backend_nodes = list()
+        for node in symbolic_nodes:
+            _, _, zero_point, qmin, qmax = self.parse_qparams(node)
+            assert qmax - qmin in (2**8 - 1, 2**8 - 2), \
+                'Only 8-bit quantization is supported for deployment to ONNX.'
+            assert not np.any(zero_point != 0), \
+                'This pass is only supposed to be used with the TensorRT ' \
+                'backend, which does not support asymmetric quantization.'
+            new_nodes = self._build_backend_node_from_symbolic(node)
+            backend_nodes.extend(new_nodes)
+        return backend_nodes
+
+    def export(self):
+        """Export an end-to-end onnx model."""
+        symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)
+        new_nodes = self.build_backend_nodes(symbolic_nodes)
+        for node in symbolic_nodes:
+            self.onnx_model.graph.node.remove(node)
+        self.onnx_model.graph.node.extend(new_nodes)
+        self.optimizer.optimize(self.onnx_model)
+        onnx.save(self.onnx_model, self.export_path)
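Mirroring the OpenVINO flow above, the intended usage is (paths illustrative):

    exporter = TensorRTExplicitExporter('symbolic_model.onnx',
                                        'deploy_model.onnx')
    exporter.export()  # writes the QDQ model for TensorRT's onnx parser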
diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py
index f24e8d538..c1edb6fe4 100644
--- a/mmrazor/models/quantizers/native_quantizer.py
+++ b/mmrazor/models/quantizers/native_quantizer.py
@@ -1,12 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import torch.nn as nn
 from mmengine.config import Config

 try:
-    from torch.ao.quantization import enable_fake_quant
+    from torch.ao.quantization import disable_observer, enable_fake_quant
     from torch.ao.quantization.fx import prepare
     from torch.ao.quantization.fx.graph_module import ObservedGraphModule
     from torch.ao.quantization.qconfig_mapping import (
@@ -16,11 +15,13 @@
     from torch.fx.graph_module import GraphModule
     from torch.nn.intrinsic.qat import modules as qat_fused_modules
     from torch.nn.qat import modules as qat_modules
+    from torch.onnx import register_custom_op_symbolic
 except ImportError:
     from mmrazor.utils import get_package_placeholder, get_placeholder
     GraphModule = get_placeholder('torch>=1.13')
     ObservedGraphModule = get_placeholder('torch>=1.13')
     enable_fake_quant = get_placeholder('torch>=1.13')
+    disable_observer = get_placeholder('torch>=1.13')
     prepare = get_placeholder('torch>=1.13')
     QConfigMapping = get_placeholder('torch>=1.13')
     _fuse_fx = get_placeholder('torch>=1.13')
@@ -62,6 +63,23 @@
         qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d,
         qat_fused_modules.LinearBn1d: qat_modules.Linear
     }
+
+    def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis,
+                                         quant_min, quant_max):
+        return g.op('mmrazor::FixedPerChannelAffine', x, scale, zero_point,
+                    ch_axis, quant_min, quant_max)
+
+    register_custom_op_symbolic('::fake_quantize_per_channel_affine',
+                                fake_quantize_per_channel_affine, 11)
+
+    def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min,
+                                        quant_max):
+        return g.op('mmrazor::FixedPerTensorAffine', x, scale, zero_point,
+                    quant_min, quant_max)
+
+    register_custom_op_symbolic('::fake_quantize_per_tensor_affine',
+                                fake_quantize_per_tensor_affine, 11)
+
 else:
     SUPPORT_QAT_MODULES = ()
     MERGE_BN_MAPPINGS = {}
@@ -181,6 +199,13 @@ def support_a_modes(self):
             per_channel."""
         return ('per_tensor')

+    def export_onnx(self,
+                    model: Union[torch.nn.Module, torch.jit.ScriptModule,
+                                 torch.jit.ScriptFunction],
+                    args: Union[Tuple[Any, ...], torch.Tensor],
+                    output_path: str, **kwargs):
+        """Export an onnx model that can be deployed to a native backend."""
+        torch.onnx.export(model, args, output_path, **kwargs)
+
     def prepare(self, model, concrete_args=None):
         """prepare graph to ObservedGraphModule.

@@ -223,16 +248,17 @@ def prepare(self, model, concrete_args=None):

         return prepared

-    def post_process_weight_fakequant(self,
-                                      observed_module: ObservedGraphModule,
-                                      keep_fake_quant: bool = False):
+    def post_process_for_deploy(self,
+                                observed_module: ObservedGraphModule,
+                                keep_w_fake_quant: bool = False):
         """weight fake-quant for supported QAT modules.

         Args:
             observed_module (ObservedGraphModule): Modules after fused and
                 observed.
-            keep_fake_quant (bool, optional): Bool to determine whether to keep
-                fake-quant op, depending on the backend. Defaults to False.
+            keep_w_fake_quant (bool, optional): Bool to determine whether to
+                keep the weight fake-quant op, depending on the backend.
+                Defaults to False.

         Note:
             `post_process_weight_fakequant()` function is necessary that the
@@ -259,7 +285,7 @@ def traverse(module):
                 # This is decided by backend type, some backend need
                 # explicitly keep the fake quant structure, others don't.
                 # TODO add deploy doc link
-                if keep_fake_quant:
+                if keep_w_fake_quant:
                     for m in float_child.modules():
                         setattr(m, 'qconfig', self.qconfig.convert())

@@ -276,13 +302,9 @@ def traverse(module):
             else:
                 traverse(child)

-        observed_module.apply(enable_fake_quant)
         traverse(observed_module)
-
-    def prepare_for_mmdeploy(self, model: nn.Module, dummy_input: Tuple,
-                             checkpoint: Optional[str]):
-        """Prepare model to Observed_model."""
-        raise NotImplementedError
+        observed_module.apply(enable_fake_quant)
+        observed_module.apply(disable_observer)

     def del_redundant_fakequant(self, prepared: GraphModule):
         """delete redundant fakequant op in prepared model.
diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py
index f8a25bd56..94658afbd 100644
--- a/mmrazor/models/quantizers/openvino_quantizer.py
+++ b/mmrazor/models/quantizers/openvino_quantizer.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple, Union

 import torch

@@ -46,32 +46,26 @@ def support_a_modes(self):
             per_channel."""
         return ('per_tensor')

-    def prepare_for_mmdeploy(self,
-                             model: torch.nn.Module,
-                             dummy_input: Tuple = (1, 3, 224, 224),
-                             checkpoint: Optional[str] = None):
-        """Prepare for deploy to the backend with mmdeploy, which will be used
-        in mmdeploy, and usually includes as follows:
-
-        1. prepare for the float model rewritten by mmdeploy.
-        2. load checkpoint consists of float weight and quantized params in
-           mmrazor.
-        3. post process weight fakequant for exporting .onnx that meet
-           the backend's requirement.
-        """
-        self.convert_batchnorm2d(model)
-        observed_model = self.prepare(model)
-        if dummy_input is not None:
-            observed_model(torch.randn(dummy_input))
-        if checkpoint is not None:
-            observed_model.load_state_dict(
-                torch.load(checkpoint)['state_dict'])
-        self.post_process_weight_fakequant(
-            observed_model, keep_fake_quant=True)
-
-        observed_model.apply(disable_observer)
-
-        return observed_model
+    def export_onnx(self,
+                    model: Union[torch.nn.Module, torch.jit.ScriptModule,
+                                 torch.jit.ScriptFunction],
+                    args: Union[Tuple[Any, ...], torch.Tensor],
+                    output_path: str,
+                    opset_version: Optional[int] = 11,
+                    **kwargs):
+        """Export an onnx model that can be deployed to the OpenVINO
+        backend."""
+
+        symbolic_output_path = f'symbolic_{output_path}'
+        torch.onnx.export(
+            model,
+            args,
+            symbolic_output_path,
+            opset_version=opset_version,
+            **kwargs)
+
+        from .exporters import OpenVinoQuantizeExportor
+        exporter = OpenVinoQuantizeExportor(symbolic_output_path, output_path)
+        exporter.export()

     @property
     def module_prev_wo_fakequant(self):
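Putting the pieces together, a sketch of the quantizer-driven export path added above (the `algorithm` handle and paths are illustrative): `export_onnx` first writes `symbolic_end2end.onnx` through the registered fakequant symbolics, then rewrites it into OpenVINO's FakeQuantize form at `end2end.onnx`.

    observed_model = algorithm.post_process_for_mmdeploy(
        dummy_input=(1, 3, 224, 224))
    algorithm.quantizer.export_onnx(
        observed_model,
        torch.randn(1, 3, 224, 224),
        'end2end.onnx',
        opset_version=11)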