Skip to content

Commit

Permalink
Merge pull request #1176 from onnx/tom/ImproveBatchNromErrors
Browse files Browse the repository at this point in the history
Improved BatchNorm errors when training is true or shape is unknown
  • Loading branch information
TomWildenhain-Microsoft committed Nov 11, 2020
2 parents 0a0a099 + e8a66a1 commit 596a334
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,7 @@ def version_11(cls, ctx, node, **kwargs):
class BatchNorm:
@classmethod
def version_6(cls, ctx, node, **kwargs):
tf_type = node.type
node.type = "BatchNormalization"
# tf inputs: x, scale, bias, mean, variance
# tf outputs: y, batch_mean, batch_var
Expand All @@ -776,14 +777,19 @@ def version_6(cls, ctx, node, **kwargs):
var_shape = ctx.get_shape(node.input[4])
val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))

if mean_shape != scale_shape:
if node.get_attr_value('is_training', 1) == 1:
logger.warning("Node %s of type %s has is_training set to true, which is not supperted. "
"Please re-save the model with training set to false.",
node.name, tf_type)

if mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
dtype=val_type)
new_mean_node_name = utils.make_name(node.name)
ctx.make_const(new_mean_node_name, new_mean_value)
ctx.replace_input(node, node.input[3], new_mean_node_name, 3)

if var_shape != scale_shape:
if var_shape != scale_shape and all(d >= 0 for d in scale_shape):
new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
dtype=val_type)
new_val_node_name = utils.make_name(node.name)
Expand Down

0 comments on commit 596a334

Please sign in to comment.