From 11342d01f27b5b727c2de032beed39c7865ae62f Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 8 Mar 2023 21:29:42 +0800 Subject: [PATCH] WIP --- .../exporters/base_quantize_exporter.py | 3 +- .../exporters/openvino_quantize_exporter.py | 2 + .../quantizers/exporters/optim_utils.py | 110 ++++++++++++++---- .../models/quantizers/openvino_quantizer.py | 72 ++++++------ 4 files changed, 132 insertions(+), 55 deletions(-) diff --git a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py index 7bde5eb72..8485fdc90 100644 --- a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py +++ b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py @@ -122,6 +122,7 @@ def _collect_symbolic_constant_inputs(self, symbolic_nodes): for node in symbolic_nodes: constant_inputs = self._get_constant_inputs(node) for constant in constant_inputs: + print(node.name, constant.name) if constant.name in collected_constant_names: continue constant_inputs.append(constant) @@ -143,7 +144,7 @@ def _remove_symbolic_related_from_onnx(self, symbolic_nodes, if remove and constant.name not in removed: self.onnx_model.graph.node.remove(constant) removed.add(constant.name) - + def export(self, onnx_path): pass diff --git a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py index e0b572086..c67f2fa6a 100644 --- a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py +++ b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py @@ -112,6 +112,8 @@ def _replace_symbolic_related(self): symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model) collect_func = self._collect_symbolic_constant_inputs + # Usually different activation fakequants share the same constant + # input, and different weight fakequants share the same constant input. 
symbolic_constant_inputs = collect_func(symbolic_nodes) build_func = self.build_backend_nodes_and_initializers diff --git a/mmrazor/models/quantizers/exporters/optim_utils.py b/mmrazor/models/quantizers/exporters/optim_utils.py index 081a29e78..96f387e74 100644 --- a/mmrazor/models/quantizers/exporters/optim_utils.py +++ b/mmrazor/models/quantizers/exporters/optim_utils.py @@ -7,7 +7,7 @@ class ONNXOptimUtils(): - + @classmethod def map_name_and_data(cls, onnx_model): params = {} @@ -55,11 +55,12 @@ def get_constant(cls, name, onnx_model): if node.op_type == 'Constant': if node.output[0] == name: return numpy_helper.to_array(node.attribute[0].t).tolist() - + @classmethod def get_initializer(cls, initializer_name, onnx_model): - return numpy_helper.to_array(onnx_model.initializer[initializer_name][0]) - + return numpy_helper.to_array( + onnx_model.initializer[initializer_name][0]) + @classmethod def get_tensor_producer(cls, output_name, output2node): if output_name not in output2node: @@ -72,8 +73,6 @@ def get_tensor_consumer(self, input_name, input2node): return ['OUTPUT_TOKEN'] return input2node[input_name] - - @classmethod def remove_node_from_onnx(cls, node, onnx_model): onnx_model.graph.node.remove(node) @@ -114,20 +113,15 @@ def find_standalone_nodes(cls, output2node = cls.map_output_and_node(onnx_model) def _is_standalone_node(node, input2node, output2node): - standalone = True for input_name in node.input: if input_name in output2node: - standalone = False - break - - if not standalone: - return False + return False for out_node in node.output: if out_node in input2node: - standalone = False + return False - return standalone + return True standalone_nodes = list() for node in onnx_model.graph.node: @@ -146,22 +140,91 @@ def find_redundant_initializers(cls, onnx_model, input2node=None): redundant_set = set() for name, init_and_idx in initializers.items(): if name not in input2node and name not in redundant_set: + # init_and_idx[0] is 
onnx.onnx_ml_pb2.TensorProto + # init_and_idx[1] is an integer index redundant_initializers.append(init_and_idx[0]) redundant_set.add(name) return redundant_initializers + @classmethod + def topo_sort2(cls, onnx_model, initializers=None, inplace=True): + + if inplace: + _onnx_model = onnx_model + else: + _onnx_model = copy.deepcopy(onnx_model) + + if initializers is None: + initializers = cls.map_name_and_initializer( + _onnx_model, allow_redundant=True) + + # A node may have multiple outputs. The first output name of a node + # named `/conv/Conv` is `/conv/Conv_output_0` + output_name2node = {} + for node in _onnx_model.graph.node: + for output_name in node.output: + output_name2node[output_name] = node + for node in _onnx_model.graph.input: + output_name2node[node.name] = node + + name2node = {node.name: node for node in _onnx_model.graph.node} + + graph = {node.name: [] for node in _onnx_model.graph.node} + for node in _onnx_model.graph.input: + graph[node.name] = [] + + indegree = {node.name: 0 for node in _onnx_model.graph.node} + + # Build graph + for i, node in enumerate(_onnx_model.graph.node): + for input_name in node.input: + if input_name not in initializers: + indegree[node.name] += 1 + prev_node = output_name2node[input_name] + graph[prev_node.name].append(node) + + graph_input = [node.name for node in _onnx_model.graph.input] + root = graph_input.copy() + sorted_nodes = [] + + # There are some nodes whose inputs are all initializers. 
+ for node_name, in_degree in indegree.items(): + if in_degree == 0: + root.append(node_name) + + while root: + node_name = root.pop() + # There is no intersection between graph_input and + # _onnx_model.graph.node + if node_name not in graph_input: + node = name2node[node_name] + sorted_nodes.append(node) + for next_node in graph[node_name]: + indegree[next_node.name] -= 1 + if indegree[next_node.name] == 0: + root.append(next_node.name) + + num_nodes = len(_onnx_model.graph.node) + if len(sorted_nodes) != num_nodes: + raise RuntimeError('The graph is not a DAG.') + + for _ in range(num_nodes): + _onnx_model.graph.node.pop() + for node in sorted_nodes: + _onnx_model.graph.node.append(node) + + return _onnx_model + @classmethod def topo_sort(cls, onnx_model, initializers=None, inplace=True): def _is_zero_in_degree(node, exist_inputs, initializers): - flag = True for input_name in node.input: if (input_name not in exist_inputs and input_name not in initializers): - flag = False - break + return False - return flag + return True if inplace: _onnx_model = onnx_model @@ -176,7 +239,7 @@ def _is_zero_in_degree(node, exist_inputs, initializers): num_nodes = len(_onnx_model.graph.node) sorted_nodes = list() - + while len(sorted_nodes) < num_nodes: find_new_node = False for i in range(num_nodes): @@ -204,12 +267,17 @@ def _is_zero_in_degree(node, exist_inputs, initializers): @classmethod def optimize(cls, onnx_model): - standalone_nodes = cls.find_standalone_nodes(onnx_model) + input2node = cls.map_input_and_node(onnx_model) + output2node = cls.map_output_and_node(onnx_model) + + standalone_nodes = cls.find_standalone_nodes(onnx_model, input2node, + output2node) for node in standalone_nodes: cls.remove_node_from_onnx(node, onnx_model) print_log(f'Remove node {node.name}') - redundant_inits = cls.find_redundant_initializers(onnx_model) + redundant_inits = cls.find_redundant_initializers( + onnx_model, input2node) for init in redundant_inits: 
cls.remove_initializer_from_onnx(init, onnx_model) print_log(f'Remove initializer {init.name}') diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 6acf66093..3c4ddd7cb 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -10,10 +10,8 @@ disable_observer = get_placeholder('torch>=1.13') from mmrazor.registry import MODELS -from .native_quantizer import NativeQuantizer from ..algorithms.quantization import MMArchitectureQuant - - +from .native_quantizer import NativeQuantizer @MODELS.register_module() @@ -48,33 +46,31 @@ def support_a_modes(self): """Supported quantization modes for activation about per_tensor or per_channel.""" return ('per_tensor') - - def export_onnx(self, model, args, output_path, export_params,input_names, output_names, opset_version, dynamic_axes, keep_initializers_as_inputs, verbose): - + + def export_onnx(self, model, args, output_path, export_params, input_names, + output_names, opset_version, dynamic_axes, + keep_initializers_as_inputs, verbose): + symbolic_output_path = f'{output_path}.symbolic' torch.onnx.export( - model, - args, - symbolic_output_path, - export_params=export_params, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - verbose=verbose) - - from .exporters import OpenVinoQuantizeExportor + model, + args, + symbolic_output_path, + export_params=export_params, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose) + + from .exporters import OpenVinoQuantizeExportor exporter = OpenVinoQuantizeExportor(symbolic_output_path, output_path) exporter.export() - - - - - + def post_process_for_mmdeploy(self, - model: MMArchitectureQuant, - 
dummy_input: Tuple = (1, 3, 224, 224)): + model: MMArchitectureQuant, + dummy_input: Tuple = (1, 3, 224, 224)): """Prepare for deploy to the backend with mmdeploy, which will be used in mmdeploy, and usually includes as follows: @@ -84,24 +80,34 @@ def post_process_for_mmdeploy(self, 3. post process weight fakequant for exporting .onnx that meet the backend's requirement. """ - - quantized_state_dict = model.qmodels['predict'].state_dict() + + quantized_state_dict = model.qmodels['tensor'].state_dict() fp32_model = model.architecture self.convert_batchnorm2d(fp32_model) - observed_model = self.prepare(fp32_model) - + observed_model = self.prepare(fp32_model, {'mode': 'tensor'}) + if dummy_input is not None: observed_model(torch.randn(dummy_input)) - + observed_model.load_state_dict(quantized_state_dict) - - self.post_process_weight_fakequant( - observed_model, keep_fake_quant=True) + + self.post_process_for_deploy(observed_model, keep_fake_quant=True) observed_model.apply(disable_observer) return observed_model + def post_process_for_torchvision(self, + model: MMArchitectureQuant, + dummy_input: Tuple = (1, 3, 224, 224)): + self.convert_batchnorm2d(model) + observed_model = self.prepare(model) + if dummy_input is not None: + observed_model(torch.randn(dummy_input)) + self.post_process_for_deploy(observed_model, keep_fake_quant=True) + observed_model.apply(disable_observer) + return observed_model + @property def module_prev_wo_fakequant(self): """Configurate the modules that their previous nodes are redundant