Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Mar 8, 2023
1 parent 0f40221 commit 11342d0
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _collect_symbolic_constant_inputs(self, symbolic_nodes):
for node in symbolic_nodes:
constant_inputs = self._get_constant_inputs(node)
for constant in constant_inputs:
print(node.name, constant.name)
if constant.name in collected_constant_names:
continue
constant_inputs.append(constant)
Expand All @@ -143,7 +144,7 @@ def _remove_symbolic_related_from_onnx(self, symbolic_nodes,
if remove and constant.name not in removed:
self.onnx_model.graph.node.remove(constant)
removed.add(constant.name)

def export(self, onnx_path):
    """Placeholder export hook: performs no work and returns ``None``.

    Args:
        onnx_path: Accepted but not used by this implementation.
            NOTE(review): presumably the destination path of the exported
            ONNX file, to be honoured by overriding classes -- confirm.
    """
    return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def _replace_symbolic_related(self):
symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model)

collect_func = self._collect_symbolic_constant_inputs
# Usually different activation fakequants share the same constant
# input, and different weight fakequants share the same constant input.
symbolic_constant_inputs = collect_func(symbolic_nodes)

build_func = self.build_backend_nodes_and_initializers
Expand Down
110 changes: 89 additions & 21 deletions mmrazor/models/quantizers/exporters/optim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class ONNXOptimUtils():

@classmethod
def map_name_and_data(cls, onnx_model):
params = {}
Expand Down Expand Up @@ -55,11 +55,12 @@ def get_constant(cls, name, onnx_model):
if node.op_type == 'Constant':
if node.output[0] == name:
return numpy_helper.to_array(node.attribute[0].t).tolist()

@classmethod
def get_initializer(cls, initializer_name, onnx_model):
return numpy_helper.to_array(onnx_model.initializer[initializer_name][0])

return numpy_helper.to_array(
onnx_model.initializer[initializer_name][0])

@classmethod
def get_tensor_producer(cls, output_name, output2node):
if output_name not in output2node:
Expand All @@ -72,8 +73,6 @@ def get_tensor_consumer(self, input_name, input2node):
return ['OUTPUT_TOKEN']
return input2node[input_name]



@classmethod
def remove_node_from_onnx(cls, node, onnx_model):
onnx_model.graph.node.remove(node)
Expand Down Expand Up @@ -114,20 +113,15 @@ def find_standalone_nodes(cls,
output2node = cls.map_output_and_node(onnx_model)

def _is_standalone_node(node, input2node, output2node):
standalone = True
for input_name in node.input:
if input_name in output2node:
standalone = False
break

if not standalone:
return False
return False

for out_node in node.output:
if out_node in input2node:
standalone = False
return False

return standalone
return True

standalone_nodes = list()
for node in onnx_model.graph.node:
Expand All @@ -146,22 +140,91 @@ def find_redundant_initializers(cls, onnx_model, input2node=None):
redundant_set = set()
for name, init_and_idx in initializers.items():
if name not in input2node and name not in redundant_set:
# init_and_idx[0] is onnx.onnx_ml_pb2.TensorProto
# init_and_idx[1] is an integer index
redundant_initializers.append(init_and_idx[0])
redundant_set.add(name)
return redundant_initializers

@classmethod
def topo_sort2(cls, onnx_model, initializers=None, inplace=True):
    """Topologically sort the nodes of an ONNX graph.

    The nodes in ``onnx_model.graph.node`` are reordered so that every
    node appears after the producers of all of its inputs. Vertices are
    tracked by position rather than by ``node.name``, so graphs with
    unnamed or duplicate-named nodes are handled correctly.

    Args:
        onnx_model: Model whose ``graph.node`` and ``graph.input`` are
            read and whose ``graph.node`` is rewritten in sorted order.
        initializers: Container of initializer names; only membership
            tests are performed on it. If ``None``, it is obtained from
            ``cls.map_name_and_initializer(..., allow_redundant=True)``.
        inplace: If ``True`` the given model is modified, otherwise a
            deep copy is sorted and the argument is left untouched.

    Returns:
        The sorted model (``onnx_model`` itself when ``inplace`` is True).

    Raises:
        KeyError: If a node consumes a tensor that is neither an
            initializer, a graph input, nor the output of any node.
        RuntimeError: If not every node can be ordered, i.e. the graph
            contains a cycle.
    """
    _onnx_model = onnx_model if inplace else copy.deepcopy(onnx_model)

    if initializers is None:
        initializers = cls.map_name_and_initializer(
            _onnx_model, allow_redundant=True)

    nodes = list(_onnx_model.graph.node)
    graph_inputs = list(_onnx_model.graph.input)
    num_nodes = len(nodes)
    num_inputs = len(graph_inputs)

    # Vertex ids: node ``i`` is ``i`` (>= 0); graph input ``j`` is
    # ``-(j + 1)`` (< 0). Using positions instead of names keeps the
    # vertices distinct even when node names are empty, duplicated, or
    # equal to a graph input name.
    #
    # A node may have multiple outputs. The first output name of a node
    # named `/conv/Conv` is `/conv/Conv_output_0`.
    producer = {}
    for idx, node in enumerate(nodes):
        for output_name in node.output:
            producer[output_name] = idx
    for idx, graph_input in enumerate(graph_inputs):
        producer[graph_input.name] = -(idx + 1)

    successors = {vertex: [] for vertex in range(-num_inputs, num_nodes)}
    indegree = [0] * num_nodes

    # Build the dependency graph.
    for idx, node in enumerate(nodes):
        for input_name in node.input:
            # An empty name marks an omitted optional input; initializers
            # are constants. Neither contributes an edge.
            if not input_name or input_name in initializers:
                continue
            if input_name not in producer:
                raise KeyError(
                    f'Input `{input_name}` of node `{node.name}` is not '
                    'produced by any node, graph input or initializer.')
            indegree[idx] += 1
            successors[producer[input_name]].append(idx)

    # Seed the stack with the graph inputs followed by the nodes whose
    # inputs are all initializers (or that have no inputs at all).
    stack = [-(idx + 1) for idx in range(num_inputs)]
    stack.extend(idx for idx, degree in enumerate(indegree) if degree == 0)

    sorted_nodes = []
    while stack:
        vertex = stack.pop()
        # Graph inputs only release their consumers; they are not nodes.
        if vertex >= 0:
            sorted_nodes.append(nodes[vertex])
        for consumer in successors[vertex]:
            indegree[consumer] -= 1
            if indegree[consumer] == 0:
                stack.append(consumer)

    if len(sorted_nodes) != num_nodes:
        raise RuntimeError('The graph is not a DAG.')

    # Replace the node list with the sorted one.
    for _ in range(num_nodes):
        _onnx_model.graph.node.pop()
    for node in sorted_nodes:
        _onnx_model.graph.node.append(node)

    return _onnx_model

@classmethod
def topo_sort(cls, onnx_model, initializers=None, inplace=True):

def _is_zero_in_degree(node, exist_inputs, initializers):
flag = True
for input_name in node.input:
if (input_name not in exist_inputs
and input_name not in initializers):
flag = False
break
return False

return flag
return True

if inplace:
_onnx_model = onnx_model
Expand All @@ -176,7 +239,7 @@ def _is_zero_in_degree(node, exist_inputs, initializers):
num_nodes = len(_onnx_model.graph.node)

sorted_nodes = list()

while len(sorted_nodes) < num_nodes:
find_new_node = False
for i in range(num_nodes):
Expand Down Expand Up @@ -204,12 +267,17 @@ def _is_zero_in_degree(node, exist_inputs, initializers):
@classmethod
def optimize(cls, onnx_model):

standalone_nodes = cls.find_standalone_nodes(onnx_model)
input2node = cls.map_input_and_node(onnx_model)
output2node = cls.map_output_and_node(onnx_model)

standalone_nodes = cls.find_standalone_nodes(onnx_model, input2node,
output2node)
for node in standalone_nodes:
cls.remove_node_from_onnx(node, onnx_model)
print_log(f'Remove node {node.name}')

redundant_inits = cls.find_redundant_initializers(onnx_model)
redundant_inits = cls.find_redundant_initializers(
onnx_model, input2node)
for init in redundant_inits:
cls.remove_initializer_from_onnx(init, onnx_model)
print_log(f'Remove initializer {init.name}')
Expand Down
72 changes: 39 additions & 33 deletions mmrazor/models/quantizers/openvino_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
disable_observer = get_placeholder('torch>=1.13')

from mmrazor.registry import MODELS
from .native_quantizer import NativeQuantizer
from ..algorithms.quantization import MMArchitectureQuant


from .native_quantizer import NativeQuantizer


@MODELS.register_module()
Expand Down Expand Up @@ -48,33 +46,31 @@ def support_a_modes(self):
"""Supported quantization modes for activation about per_tensor or
per_channel."""
return ('per_tensor')

def export_onnx(self, model, args, output_path, export_params,input_names, output_names, opset_version, dynamic_axes, keep_initializers_as_inputs, verbose):


def export_onnx(self, model, args, output_path, export_params, input_names,
output_names, opset_version, dynamic_axes,
keep_initializers_as_inputs, verbose):

symbolic_output_path = f'{output_path}.symbolic'
torch.onnx.export(
model,
args,
symbolic_output_path,
export_params=export_params,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)
from .exporters import OpenVinoQuantizeExportor
model,
args,
symbolic_output_path,
export_params=export_params,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
verbose=verbose)

from .exporters import OpenVinoQuantizeExportor
exporter = OpenVinoQuantizeExportor(symbolic_output_path, output_path)
exporter.export()






def post_process_for_mmdeploy(self,
model: MMArchitectureQuant,
dummy_input: Tuple = (1, 3, 224, 224)):
model: MMArchitectureQuant,
dummy_input: Tuple = (1, 3, 224, 224)):
"""Prepare for deploy to the backend with mmdeploy, which will be used
in mmdeploy, and usually includes as follows:
Expand All @@ -84,24 +80,34 @@ def post_process_for_mmdeploy(self,
3. post process weight fakequant for exporting .onnx that meet
the backend's requirement.
"""
quantized_state_dict = model.qmodels['predict'].state_dict()

quantized_state_dict = model.qmodels['tensor'].state_dict()
fp32_model = model.architecture
self.convert_batchnorm2d(fp32_model)
observed_model = self.prepare(fp32_model)
observed_model = self.prepare(fp32_model, {'mode': 'tensor'})

if dummy_input is not None:
observed_model(torch.randn(dummy_input))

observed_model.load_state_dict(quantized_state_dict)

self.post_process_weight_fakequant(
observed_model, keep_fake_quant=True)

self.post_process_for_deploy(observed_model, keep_fake_quant=True)

observed_model.apply(disable_observer)

return observed_model

def post_process_for_torchvision(
        self,
        model: MMArchitectureQuant,
        dummy_input: Tuple = (1, 3, 224, 224)):
    """Prepare ``model`` and post-process it for deployment.

    Steps, in order: ``convert_batchnorm2d`` is applied to ``model``, the
    model is passed through ``prepare``, an optional forward pass on random
    data is executed, ``post_process_for_deploy`` is run with
    ``keep_fake_quant=True`` and finally ``disable_observer`` is applied
    to the prepared model.

    Args:
        model: The model to process; it is handed to
            ``convert_batchnorm2d`` and ``prepare`` as-is.
        dummy_input: Size passed to ``torch.randn`` to build the random
            tensor for the forward pass. ``None`` skips the forward pass.

    Returns:
        The model returned by ``prepare`` after post-processing.
    """
    self.convert_batchnorm2d(model)
    prepared = self.prepare(model)

    run_forward = dummy_input is not None
    if run_forward:
        # One forward pass on random data of the requested size.
        random_batch = torch.randn(dummy_input)
        prepared(random_batch)

    self.post_process_for_deploy(prepared, keep_fake_quant=True)
    prepared.apply(disable_observer)
    return prepared

@property
def module_prev_wo_fakequant(self):
"""Configurate the modules that their previous nodes are redundant
Expand Down

0 comments on commit 11342d0

Please sign in to comment.