Fix yolo-v2 and aipg-vdcnn tuning failed (#1204)
lvliang-intel committed Sep 5, 2022
1 parent f5ce0b4 commit ec84fd6
Showing 9 changed files with 70 additions and 58 deletions.
9 changes: 7 additions & 2 deletions neural_compressor/adaptor/inteltensorflow.yaml
@@ -24,8 +24,8 @@

ops: &common_ops
int8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'FusedBatchNorm', 'FusedBatchNormV2','FusedBatchNormV3',
'MatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
uint8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'MatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
uint8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
bf16: ["Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", "Conv3D", "Conv3DBackpropFilterV2", "Conv3DBackpropInputV2",
"DepthwiseConv2dNative", "DepthwiseConv2dNativeBackpropFilter", "DepthwiseConv2dNativeBackpropInput", "GRUBlockCell",
"AUGRUBlockCell", "MklGRU", "MklAUGRU", "MatMul", "BatchMatMul", "BatchMatMulV2", # allow_list
@@ -291,10 +291,15 @@
'Dequantize + MatMul + Elu + QuantizeV2',
'Dequantize + MatMul + Tanh + QuantizeV2',
'Dequantize + MatMul + Sigmoid + QuantizeV2',
'Dequantize + BatchMatMul + Mul + QuantizeV2',
'Dequantize + BatchMatMulV2 + Mul + QuantizeV2',
'Dequantize + BatchMatMul + Add + QuantizeV2',
'Dequantize + BatchMatMulV2 + Add + QuantizeV2',
'Dequantize + BatchMatMul + AddV2 + QuantizeV2',
'Dequantize + BatchMatMulV2 + AddV2 + QuantizeV2',
'Dequantize + BatchMatMul + Mul + Add + QuantizeV2',
'Dequantize + BatchMatMulV2 + Mul + Add + QuantizeV2',
'Dequantize + BatchMatMul + Mul + AddV2 + QuantizeV2',
'Dequantize + BatchMatMulV2 + Mul + AddV2 + QuantizeV2',
'Dequantize + Conv3D + AddV2 + AddV2 + Relu + QuantizeV2',
'Dequantize + Conv3D + Add + Relu + QuantizeV2',
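Each of the entries added above follows the YAML's existing 'OpA + OpB + ... + QuantizeV2' convention: a '+'-separated op sequence that the quantizer may fuse into a single int8 kernel, here extended to cover BatchMatMul alongside BatchMatMulV2. A minimal sketch of decomposing such a pattern string into its op sequence (the helper name is hypothetical, not the repository's parser):

def split_fusion_pattern(pattern):
    # Split a '+'-separated fusion pattern into its ordered op names.
    return [op.strip() for op in pattern.split('+')]

print(split_fusion_pattern('Dequantize + BatchMatMul + Mul + AddV2 + QuantizeV2'))
# ['Dequantize', 'BatchMatMul', 'Mul', 'AddV2', 'QuantizeV2']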
1 change: 1 addition & 0 deletions neural_compressor/adaptor/tensorflow.py
@@ -43,6 +43,7 @@ class TensorFlowAdaptor(Adaptor):
"AvgPool": "pooling",
"ConcatV2": "concat",
"MatMul": "matmul",
"BatchMatMul": "matmul",
"BatchMatMulV2": "matmul",
"Pad": "pad"
}
9 changes: 1 addition & 8 deletions neural_compressor/adaptor/tf_utils/graph_converter.py
@@ -63,7 +63,6 @@
from .graph_rewriter.int8.post_quantized_op_cse import PostCseOptimizer
from .graph_rewriter.int8.post_hostconst_converter import PostHostConstConverter
from .graph_rewriter.int8.meta_op_optimizer import MetaInfoChangingMemOpOptimizer
from .graph_rewriter.int8.rnn_convert import QuantizedRNNConverter
from .graph_rewriter.qdq.insert_qdq_pattern import GenerateGraphWithQDQPattern
from neural_compressor.adaptor.tf_utils.graph_rewriter.generic.insert_print_node import InsertPrintMinMaxNode
from .graph_util import GraphRewriterHelper as Helper
@@ -423,11 +422,6 @@ def quantize(self):
"""
try:
self._quantize_graph()

self._rnn_details = Helper.analysis_rnn_model(self._tmp_graph_def,
bf16_ops=self.bf16_ops,
fp32_ops=self.fp32_ops)
self.quantized_node_info.extend(self._rnn_details.keys())
self.quantized_node_info = [tuple(i) for i in self.quantized_node_info]

if self.fake_quant:
@@ -452,9 +446,8 @@ def quantize(self):
self.op_wise_config).do_transformation()

for i in self.quantized_node_info:
frame_name = self._rnn_details[i] if i in self._rnn_details else None
sampling_graph_def, output_names = InsertPrintMinMaxNode(
sampling_graph_def, i[0], i[-1], frame_name).do_transformation()
sampling_graph_def, i[0], i[-1]).do_transformation()
output_tensor_names.extend(output_names)
if self.quantized_node_info:
sampling_graph_def.library.CopyFrom(self.model.graph_def.library)
neural_compressor/adaptor/tf_utils/graph_rewriter/generic/insert_print_node.py
@@ -29,7 +29,7 @@ class InsertPrintMinMaxNode(GraphRewriterBase):
"""InsertPrintMinMaxNode Pass for tensorflow sampling.
"""

def __init__(self, model, pre_node_name, post_node_name, frame_name=None):
def __init__(self, model, pre_node_name, post_node_name):
super().__init__(model)
self.pre_node_name = pre_node_name
self.post_node_name = post_node_name
@@ -143,17 +143,20 @@ def do_transformation(self):
attr_value_pb2.AttrValue.ListValue(type=attr_u))
post_node_names = graph_info[Helper.node_name_from_input(each_node_name)].outputs
if post_node_names:
identity_node0 = None
identity_node1 = None
for post_node_name in post_node_names:
post_node = graph_info[post_node_name].node
if post_node.op == 'FusedBatchNormV3':
identity_node0 = Helper.create_node(
"Identity", min_print_node.name+'_identity', [min_print_node.name])
identity_node0.attr["T"].CopyFrom(src_dt)
identity_node1 = Helper.create_node(
"Identity", max_print_node.name+'_identity', [max_print_node.name])
identity_node1.attr["T"].CopyFrom(src_dt)
if "_print_identity" in \
graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]:
continue
identity_node = Helper.create_node("Identity", post_node.name+'_print_identity',
[graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]])
identity_node.attr["T"].CopyFrom(src_dt)
cur_graph.add_node(identity_node,
graph_info[Helper.node_name_from_input(post_node.name)].node.input[0],
[post_node.name])
identity_node.input.append("^" + min_print_node.name)
identity_node.input.append("^" + max_print_node.name)
else:
post_node.input.append("^" + min_print_node.name)
post_node.input.append("^" + max_print_node.name)
@@ -165,16 +168,8 @@ def do_transformation(self):
cur_graph.add_node(max_input_node, reshape_input_name, [max_print_node.name])
cur_graph.add_node(min_input_node, reshape_input_name, [min_print_node.name])

if identity_node0 and identity_node1:
cur_graph.add_node(min_print_node, min_input_name, [identity_node0.name])
cur_graph.add_node(max_print_node, max_input_name, [identity_node1.name])
cur_graph.add_node(identity_node0, min_print_node.name, [])
cur_graph.add_node(identity_node1, max_print_node.name, [])
output_names.append(identity_node0.name)
output_names.append(identity_node1.name)
else:
cur_graph.add_node(min_print_node, min_input_name, [])
cur_graph.add_node(max_print_node, max_input_name, [])
cur_graph.add_node(min_print_node, min_input_name, [])
cur_graph.add_node(max_print_node, max_input_name, [])
else:
identity_node0 = Helper.create_node(
"Identity", min_print_node.name+'_identity', [min_print_node.name])
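The rewritten branch above replaces the FusedBatchNormV3 special case: for every consumer of the sampled node it creates a '<consumer>_print_identity' Identity node and attaches the min/max Print nodes to it as control dependencies. As a standalone sketch of the GraphDef convention being relied on (plain protobuf, not the repository's Helper class; node names are illustrative): data inputs are plain node names, while a '^' prefix marks a control dependency, which forces the Print nodes to run during calibration even though nothing consumes their outputs.

from tensorflow.core.framework import node_def_pb2

identity_node = node_def_pb2.NodeDef()
identity_node.op = "Identity"
identity_node.name = "conv1_print_identity"        # illustrative name
identity_node.input.append("conv1")                # data input
identity_node.input.append("^conv1_min_print")     # control dependency
identity_node.input.append("^conv1_max_print")     # control dependency
print(identity_node)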
@@ -140,11 +140,6 @@ def get_optimized_model(self, itex_mode=False):
self._tmp_graph_def = RenameBatchNormOptimizer(
self._tmp_graph_def).do_transformation()

#TODO we should handle all control ops elegantly not bypass it.
if not self.new_api:
self._tmp_graph_def, excluded_node_names = UpdateEnterOptimizer(
self._tmp_graph_def).do_transformation()

self._tmp_graph_def = ConvertLeakyReluOptimizer(
self._tmp_graph_def).do_transformation()

@@ -179,8 +174,6 @@ def get_optimized_model(self, itex_mode=False):
if self.new_api:
self._tmp_graph_def = DilatedContraction(
self._tmp_graph_def).do_transformation()
if not self.new_api:
self._excluded_node_names.extend(excluded_node_names)
self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)

origin_model.graph_def = self._tmp_graph_def
neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py
@@ -96,7 +96,7 @@ def do_transformation(self):
self.g_weight.graph = self.g.dump_graph()
self.graph_info = self.g_weight.parse_graph()
target_nodes = self.g_weight.query_fusion_pattern_nodes(
[["Conv2D", "Conv3D", "DepthwiseConv2dNative", "MatMul", "BatchMatMulV2"]])
[["Conv2D", "Conv3D", "DepthwiseConv2dNative", "MatMul", "BatchMatMul", "BatchMatMulV2"]])
for i in target_nodes:
if i[0] not in quantizable_op_names:
continue
@@ -130,7 +130,8 @@ def do_transformation(self):
def _check_op_list(self, node_type):
op_list = ("ConcatV2", "Conv2D", "Conv3D", "DepthwiseConv2D", "QuantizeV2", "DepthwiseConv2dNative",
"MaxPool", "MaxPool3D", "FusedBatchNormV3", "Requantize", "RequantizePerChannel", "AvgPool", "Pad",
"CropAndResize", "Dequantize", "Mean", "MatMul", "BatchMatMulV2", "FakeQuantWithMinMaxVars")
"CropAndResize", "Dequantize", "Mean", "MatMul", "BatchMatMul",
"BatchMatMulV2", "FakeQuantWithMinMaxVars")
return any([node_type.find(i) != -1 for i in op_list])

def _find_relu_node(self, node):
@@ -290,6 +291,10 @@ def _insert_qdq_pattern_for_each_input(self, op_name, namespace_prefix,
Helper.set_attr_dtype(max_input_node, "T", dtypes.float32)
Helper.set_attr_dtype(max_input_node, "Tidx", dtypes.int32)
Helper.set_attr_bool(max_input_node, "keep_dims", False)

if "BatchMatMul" in self.graph_info[op_name].node.op:
min_input_node.input.append("^" + input_name)
max_input_node.input.append("^" + input_name)

quant_v2_node = Helper.create_node("QuantizeV2", quantize_input_name,
[input_name, min_input_name, max_input_name])
@@ -342,7 +347,7 @@ def _insert_qdq_pattern_for_weight_node(self,

# The weight node of BatchMatMul may have no value
if 'value' in weight_node.attr and \
host_op_type in ("Conv2D", "MatMul", "BatchMatMulV2", "Conv3D"):
host_op_type in ("Conv2D", "MatMul", "BatchMatMul", "BatchMatMulV2", "Conv3D"):
float_tensor = tensor_util.MakeNdarray(weight_node.attr["value"].tensor)
if per_channel:
if host_op_type == 'Conv3D':
@@ -406,6 +411,10 @@ def _insert_qdq_pattern_for_weight_node(self,
dtypes.float32, device="cpu")
max_node = Helper.create_constant_node(max_name, max_value,
dtypes.float32, device="cpu")
if "BatchMatMul" in host_op_type and "BatchMatMul" not in weight_node.op:
min_node.input.append("^" + weight_node.name)
max_node.input.append("^" + weight_node.name)

quant_node = Helper.create_node(
"QuantizeV2", qint8_const_name + '_quant',
[weight_node.name, min_name, max_name])
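The hunks above extend the QDQ insertion pass so that BatchMatMul (not only BatchMatMulV2) inputs and weights receive a QuantizeV2/Dequantize pair, with extra control edges anchoring the min/max nodes to the producing node. A rough standalone illustration of the Q->DQ structure being inserted, written with plain TensorFlow ops rather than the repository's graph-rewriting helpers (tensor shapes and names are illustrative):

import tensorflow as tf

# Float activation feeding a BatchMatMul-style op.
x = tf.random.normal([2, 4, 8])

# Per-tensor min/max, analogous to the Min/Max reduction nodes created above.
min_x = tf.reduce_min(x)
max_x = tf.reduce_max(x)

# QuantizeV2 consumes (input, min, max); Dequantize restores float, so the
# graph stays numerically float until fused int8 kernels replace the pattern.
q = tf.raw_ops.QuantizeV2(input=x, min_range=min_x, max_range=max_x, T=tf.qint8)
dq = tf.raw_ops.Dequantize(input=q.output, min_range=q.output_min, max_range=q.output_max)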
8 changes: 3 additions & 5 deletions neural_compressor/adaptor/tf_utils/graph_util.py
@@ -874,14 +874,12 @@ def generate_int32_bias_for_matmul(bias_tensor, weights_tensor,
input_range * max(abs(max_filter_value), abs(min_filter_value)))
relative_scale = 255 * min_input / (max_input - min_input)
int32_bias = []
axis_value = 0
if weights_tensor.ndim == 2:
if weights_tensor.shape[0] == bias_tensor.shape[0]:
axis_value = 1
for bias_index, value in enumerate(
np.sum(np.array(weights_tensor, dtype=np.int32),
axis=axis_value,
axis=0,
dtype=np.int32)):
if bias_index >= bias_tensor.size:
continue
int32_bias.append(int(np.around(bias_tensor[bias_index] *
bias_scale + value * relative_scale)))

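The change above simplifies generate_int32_bias_for_matmul: the 2-D weight tensor is now always summed along axis 0, yielding one correction term per output column, instead of picking the axis from the bias shape. A minimal numpy sketch of the computation, assuming bias_scale and relative_scale have already been derived as in the surrounding lines (function and argument names here are illustrative):

import numpy as np

def int32_bias_for_matmul(bias_fp32, weights_2d, bias_scale, relative_scale):
    # One weight-sum correction term per output column (axis=0), matching
    # the simplified code path in this commit.
    col_sums = np.sum(np.array(weights_2d, dtype=np.int32), axis=0, dtype=np.int32)
    int32_bias = []
    for bias_index, value in enumerate(col_sums):
        if bias_index >= bias_fp32.size:
            continue
        int32_bias.append(int(np.around(bias_fp32[bias_index] * bias_scale
                                        + value * relative_scale)))
    return int32_bias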
@@ -54,11 +54,17 @@ def __init__(self, **kwargs):
'DequantizeMatMulEluQuantizeV2': self.apply_matmul_biasadd_relu_fusion,
'DequantizeMatMulTanhQuantizeV2': self.apply_matmul_biasadd_relu_fusion,
'DequantizeMatMulSigmoidQuantizeV2': self.apply_matmul_biasadd_relu_fusion,
'DequantizeBatchMatMulQuantizeV2': self.apply_batchmatmulv2_fusion,
'DequantizeBatchMatMulV2QuantizeV2': self.apply_batchmatmulv2_fusion,
'DequantizeBatchMatMulMulQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulV2MulQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulAddQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulV2AddQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulAddV2QuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulV2AddV2QuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulMulAddV2QuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulV2MulAddV2QuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulMulAddQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion,
'DequantizeBatchMatMulV2MulAddQuantizeV2': self.apply_batchmatmulv2_mul_add_fusion
}

@@ -883,6 +889,10 @@ def _is_match_matmul(self, patterns, qdq_inserted=False):
self.node_name_mapping.keys())[k]].node
if cur_node.name != self.start_node_name:
continue

if not self.performance_only and (cur_node.op == 'BatchMatMulV2' or
cur_node.op == 'BatchMatMul'):
continue

_, normal_inputs = self._get_node_input(cur_node.name)
weight_name = normal_inputs[1]
@@ -891,22 +901,29 @@ def _is_match_matmul(self, patterns, qdq_inserted=False):
# FIXME We only quantize the MatMul op whose second input node type is Const. This is a
# workaround for RNN models like LSTM.
parent_node = None
if weight_node.op != 'Const':
if weight_node.input:
parent_node = \
self.node_name_mapping[helper.node_name_from_input(weight_node.input[0])].node
if weight_node.op == 'Enter':
if len(self.node_name_mapping[helper.node_name_from_input(weight_name)].output) > 1:
continue
if parent_node.op == 'Const':
weight_node = parent_node
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
if np.any(np.isnan(weights_content)):
if cur_node.op == "MatMul":
if weight_node.op != 'Const':
if not self.performance_only:
continue

if weight_node.input:
parent_node = \
self.node_name_mapping[helper.node_name_from_input(weight_node.input[0])].node
if weight_node.op == 'Enter':
if len(self.node_name_mapping[helper.node_name_from_input(weight_name)].output)>1:
continue
else:
continue
if parent_node.op == 'Const':
weight_node = parent_node
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
if np.any(np.isnan(weights_content)):
continue
else:
continue
else:
weights_content = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
if np.any(np.isnan(weights_content)):
continue

if cur_node.op == "MatMul":
# TODO Remove the two lines below once TF enables QuantizedMatMul with
# transpose_a set to True.
if cur_node.attr["transpose_a"].b == True:
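The mapping added at the top of this file keys each fused sequence by the concatenation of its op names, e.g. 'Dequantize' + 'BatchMatMulV2' + 'Mul' + 'AddV2' + 'QuantizeV2' becomes 'DequantizeBatchMatMulV2MulAddV2QuantizeV2', and routes it to a fusion handler. A minimal sketch of that dispatch idea (illustrative only; not necessarily how the repository builds its keys):

def dispatch_fusion(matched_ops, fusion_mapping):
    # matched_ops: ordered op names of the matched subgraph, e.g.
    # ['Dequantize', 'BatchMatMulV2', 'Mul', 'AddV2', 'QuantizeV2'].
    key = ''.join(matched_ops)
    handler = fusion_mapping.get(key)
    if handler is None:
        raise KeyError("no fusion handler registered for %s" % key)
    return handler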
@@ -75,6 +75,7 @@ def __init__(self, input_graph, input_node_names, output_node_names, op_wise_con
self.register_transformer("AvgPool", FuseNodeStartWithPooling)
self.register_transformer("ConcatV2", FuseNodeStartWithConcatV2)
self.register_transformer("MatMul", FuseNodeStartWithMatmul)
self.register_transformer("BatchMatMul", FuseNodeStartWithMatmul)
self.register_transformer("BatchMatMulV2", FuseNodeStartWithMatmul)

def get_quantized_nodes(self):
