Skip to content

Commit

Permalink
Fix ConvTranspose: enhance attribute check (onnx#3000)
Browse files Browse the repository at this point in the history
* add check for using auto_pad and pads simultaneously

* fix description for auto_pads == SAME_UPPER

* update docs for operator

* fix the old one as well

* add a test

* Revert "fix description for auto_pads == SAME_UPPER"

This reverts commit e75e287.

* Revert "update docs for operator"

This reverts commit 70952c0.

* Revert "fix the old one as well"

This reverts commit 8a0482d.

Co-authored-by: Ashwini Khade <askhade@microsoft.com>
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
  • Loading branch information
jcwchen and askhade committed Sep 22, 2020
1 parent 5def930 commit b1cba0a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
4 changes: 4 additions & 0 deletions onnx/defs/nn/defs.cc
Expand Up @@ -1191,6 +1191,10 @@ void convTransposeShapeInference(InferenceContext& ctx) {
if (pads.size() != n_input_dims * 2) {
fail_shape_inference("Attribute pads has incorrect size");
}
const auto* auto_pad_attr = ctx.getAttribute("auto_pad");
if (nullptr != auto_pad_attr) {
fail_shape_inference("The pads attribute cannot be used simultaneously with auto_pad attribute");
}
} else {
pads.assign(n_input_dims * 2, 0);
const auto* auto_pad_attr = ctx.getAttribute("auto_pad");
Expand Down
10 changes: 10 additions & 0 deletions onnx/test/shape_inference_test.py
Expand Up @@ -1537,6 +1537,16 @@ def test_conv_transpose_with_group_and_output_shape(self): # type: () -> None
[])
self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.FLOAT, (25, 64, 36, 36))])

def test_conv_transpose_with_pads_and_auto_pads(self): # type: () -> None
# This test should fail because pads cannot be used simultaneously with auto_pad
graph = self._make_graph(
[('X', TensorProto.FLOAT, (1, 1, 2, 2)),
('W', TensorProto.FLOAT, (1, 1, 3, 3)),
('B', TensorProto.FLOAT, (1, ))],
[make_node('ConvTranspose', ['X', 'W', 'B'], 'Y', auto_pad="SAME_UPPER", strides=[1, 1], pads=[0, 1, 1, 0])],
[])
self.assertRaises(RuntimeError, onnx.shape_inference.infer_shapes, helper.make_model(graph))

def test_mvn_function_output_shape(self): # type: () -> None
graph = self._make_graph(
[('X', TensorProto.FLOAT, (25, 48, 16, 16))],
Expand Down

0 comments on commit b1cba0a

Please sign in to comment.