Skip to content

Commit

Permalink
Merge pull request #1106 from onnx/gs/fix-f16-fuse
Browse files Browse the repository at this point in the history
fix bn fuse for fp16
  • Loading branch information
guschmue committed Sep 16, 2020
2 parents 010fad1 + 472770c commit 353e46f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tf2onnx/optimizer/back_to_back_optimizer.py
Expand Up @@ -201,12 +201,13 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
if len(weights.shape) != 4:
return []

bias = 0
# optional bias value
if len(node.inputs) > 2:
if not node.inputs[2].is_const():
return []
bias = node.inputs[2].get_tensor_value(as_list=False)
else:
bias = np.array(0, dtype=weights.dtype)

# scale, offset, mean, var be const, otherwise skip
if False in [node2.inputs[i].is_const() for i in [1, 2, 3, 4]]:
Expand All @@ -228,8 +229,8 @@ def _optimize_conv_batchnorm_fusion(g, node, consumer_nodes):
weights_new = weights * scale_new
weights_new = weights_new.transpose(3, 2, 0, 1)
bias_new = (bias - mean) * scale_new + offset
bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new)
weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new)
bias_new_const = g.make_const(node.name + '_bias_fused_bn', bias_new.astype(bias.dtype))
weights_new_const = g.make_const(node.name + '_weights_fused_bn', weights_new.astype(weights.dtype))
g.replace_inputs(node, [node.input[0], weights_new_const.output[0], bias_new_const.output[0]])

# fuse conv and bn, delete bn
Expand Down

0 comments on commit 353e46f

Please sign in to comment.