Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved BatchNorm errors when training is true or shape is unknown #1176

Merged
merged 1 commit into from Nov 11, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions tf2onnx/onnx_opset/nn.py
Expand Up @@ -755,6 +755,7 @@ def version_11(cls, ctx, node, **kwargs):
class BatchNorm:
@classmethod
def version_6(cls, ctx, node, **kwargs):
tf_type = node.type
node.type = "BatchNormalization"
# tf inputs: x, scale, bias, mean, variance
# tf outputs: y, batch_mean, batch_var
Expand All @@ -776,14 +777,19 @@ def version_6(cls, ctx, node, **kwargs):
var_shape = ctx.get_shape(node.input[4])
val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))

if mean_shape != scale_shape:
if node.get_attr_value('is_training', 1) == 1:
logger.warning("Node %s of type %s has is_training set to true, which is not supperted. "
"Please re-save the model with training set to false.",
node.name, tf_type)

if mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
dtype=val_type)
new_mean_node_name = utils.make_name(node.name)
ctx.make_const(new_mean_node_name, new_mean_value)
ctx.replace_input(node, node.input[3], new_mean_node_name, 3)

if var_shape != scale_shape:
if var_shape != scale_shape and all(d >= 0 for d in scale_shape):
new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
dtype=val_type)
new_val_node_name = utils.make_name(node.name)
Expand Down