From c24cd429c4ab47d2b057da8174788c39c60f8760 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Sun, 15 Aug 2021 13:37:09 -0700 Subject: [PATCH] Batchnorm shape inference (#3657) * Fix batchnorm shape inference Signed-off-by: Ganesan Ramalingam * Undo format change of other code Signed-off-by: Ganesan Ramalingam --- onnx/defs/nn/defs.cc | 14 +++++++++++++- onnx/defs/nn/old.cc | 14 +++++++++++++- onnx/test/shape_inference_test.py | 22 ++++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc index c462fabd9cf..51fa6ec5eeb 100644 --- a/onnx/defs/nn/defs.cc +++ b/onnx/defs/nn/defs.cc @@ -1748,9 +1748,21 @@ ONNX_OPERATOR_SET_SCHEMA( propagateShapeAndTypeFromFirstInput(ctx); propagateShapeFromInputToOutput(ctx, 0, 0); + // Inputs 1 to 4 must be of rank 1. + checkInputRank(ctx, 1, 1); + checkInputRank(ctx, 2, 1); + checkInputRank(ctx, 3, 1); + checkInputRank(ctx, 4, 1); + Dim num_channels; - unifyInputDim(ctx, 0, 1, num_channels); + if (hasInputShape(ctx, 0)) { + if (getInputShape(ctx, 0).dim_size() > 1) + unifyInputDim(ctx, 0, 1, num_channels); + else + unifyDim(num_channels, 1); + } + unifyInputDim(ctx, 1, 0, num_channels); unifyInputDim(ctx, 2, 0, num_channels); unifyInputDim(ctx, 3, 0, num_channels); diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc index 7d523ec0fae..4c38745bb6b 100644 --- a/onnx/defs/nn/old.cc +++ b/onnx/defs/nn/old.cc @@ -1888,9 +1888,21 @@ ONNX_OPERATOR_SET_SCHEMA( propagateShapeAndTypeFromFirstInput(ctx); propagateShapeFromInputToOutput(ctx, 0, 0); + // Inputs 1 to 4 must be of rank 1. + checkInputRank(ctx, 1, 1); + checkInputRank(ctx, 2, 1); + checkInputRank(ctx, 3, 1); + checkInputRank(ctx, 4, 1); + Dim num_channels; - unifyInputDim(ctx, 0, 1, num_channels); + if (hasInputShape(ctx, 0)) { + if (getInputShape(ctx, 0).dim_size() > 1) + unifyInputDim(ctx, 0, 1, num_channels); + else + unifyDim(num_channels, 1); + } + unifyInputDim(ctx, 1, 0, num_channels); unifyInputDim(ctx, 2, 0, num_channels); unifyInputDim(ctx, 3, 0, num_channels); diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 9b36e7615ab..cbbe5aa80bf 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -1322,6 +1322,28 @@ def test_batch_norm(self): # type: () -> None []) self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.FLOAT, (3, 4, 5, 6, 7))]) + def test_batch_norm_rank1(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (128,)), # 1-dimensional permitted + ('scale', TensorProto.FLOAT, (1,)), + ('b', TensorProto.FLOAT, (1,)), + ('mean', TensorProto.FLOAT, (1,)), + ('var', TensorProto.FLOAT, (1,))], + [make_node('BatchNormalization', ['x', 'scale', 'b', 'mean', 'var'], ['out'])], + []) + self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.FLOAT, (128,))]) + + def test_batch_norm_invalid(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (128,)), + ('scale', TensorProto.FLOAT, (1, 2)), # invalid rank + ('b', TensorProto.FLOAT, (1,)), + ('mean', TensorProto.FLOAT, (1,)), + ('var', TensorProto.FLOAT, (1,))], + [make_node('BatchNormalization', ['x', 'scale', 'b', 'mean', 'var'], ['out'])], + []) + self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph) + def test_split_negative_axis(self): # type: () -> None graph = self._make_graph( [('x', TensorProto.FLOAT, (2, 4))],