diff --git a/tests/test_backend.py b/tests/test_backend.py
index c6a880fe8..04fd1626b 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -164,12 +164,12 @@ def get_conv_getdata(kind=1):
 
 def get_maxpoolwithargmax_getdata():
     data = [
-        ('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
-        ('SAME', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
-        ('SAME', [1, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
-        ('SAME', [1, 10, 5, 1], [1, 4, 4, 1], [1, 1, 1, 1]),
-        ('VALID', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
-        ('VALID', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
+        ('SAME', [1, 3, 3, 2], [1, 3, 3, 1], [1, 2, 2, 1]),
+        ('SAME', [2, 5, 5, 3], [1, 4, 4, 1], [1, 2, 2, 1]),
+        ('SAME', [2, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
+        ('SAME', [2, 10, 5, 3], [1, 4, 4, 1], [1, 1, 1, 1]),
+        ('VALID', [2, 3, 3, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
+        ('VALID', [2, 5, 5, 3], [1, 4, 4, 1], [1, 2, 2, 1]),
     ]
     for idx, v in enumerate(data):
         yield (idx,) + v
@@ -3738,13 +3738,43 @@ def func(x):
     def test_maxpoolwithargmax(self):
         for p in get_maxpoolwithargmax_getdata():
             _, padding, x_shape, ksize, strides = p
-            x_val = make_xval(x_shape)
+            x_val = np.random.uniform(0, 10, x_shape).astype(np.float32)
             def func(x):
                 mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding)
                 return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
             self.logger.debug(str(p))
             self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})
 
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(11, "MaxPoolWithArgmax")
+    def test_maxpoolwithargmax_batch_in_index(self):
+        """MaxPoolWithArgmax with include_batch_in_index=True on a batch of 2."""
+        padding = 'SAME'
+        x_shape = [2, 10, 5, 3]
+        ksize = [1, 4, 4, 1]
+        strides = [1, 1, 1, 1]
+        x_val = np.random.uniform(0, 10, x_shape).astype(np.float32)
+        def func(x):
+            mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding, include_batch_in_index=True)
+            return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})
+
+    @check_tf_min_version("1.15")
+    @check_opset_min_version(11, "MaxPoolWithArgmax")
+    def test_maxpoolwithargmax_unknown_c(self):
+        """MaxPoolWithArgmax when the channel count comes from a shape fed at runtime."""
+        padding = 'SAME'
+        x_shape = [2, 10, 5, 1]
+        ksize = [1, 4, 4, 1]
+        strides = [1, 1, 1, 1]
+        x_val = np.random.uniform(0, 10, x_shape).astype(np.float32)
+        s_val = np.array([2, 10, 5, 4], np.int64)
+        def func(x, s):
+            x = tf.broadcast_to(x, s)
+            mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding, include_batch_in_index=True)
+            return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: s_val})
+
     @check_opset_min_version(10, "Selu")
     def test_selu(self):
         x_val = np.random.random_sample([3]).astype(np.float32)
diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index afb3f5d99..491fada6d 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -694,5 +694,45 @@ def version_8(cls, ctx, node, **kwargs):
         # The input data_format is NHWC for TF MaxPoolWithArgmax
         node.set_attr("data_format", "NHWC")
 
+        # Convert indices from NCHW to NHWC format
+        input_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
+        input_shape_guess = ctx.get_shape(node.input[0])
+        n, h, w, c = ctx.make_node("Split", [input_shape], attr={'axis': 0}, output_count=4).output
+        hw = ctx.make_node("Mul", [h, w]).output[0]
+        chw = ctx.make_node("Mul", [hw, c]).output[0]
+        # Snapshot the consumers now: the nodes built below also read node.output[1] and must not be rewired
+        consumers = ctx.find_output_consumers(node.output[1])
+        # xy = index mod (h * w); Mod is only used from opset 10 on, older opsets emulate it with Div/Mul/Sub
+        if ctx.opset >= 10:
+            xy = ctx.make_node("Mod", [node.output[1], hw]).output[0]
+        else:
+            xy_div = ctx.make_node("Div", [node.output[1], hw]).output[0]
+            xy_mul = ctx.make_node("Mul", [xy_div, hw]).output[0]
+            xy = ctx.make_node("Sub", [node.output[1], xy_mul]).output[0]
+        xy_scale_c = ctx.make_node("Mul", [xy, c]).output[0]
+        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
+        # Channel offsets 0..c-1: a constant when the static shape gives the channel count, else a runtime Range
+        if input_shape_guess is not None and input_shape_guess[3] > 0:
+            c_range_np = np.arange(input_shape_guess[3], dtype=np.int64)
+            c_range = ctx.make_const(utils.make_name("c_range"), c_range_np).output[0]
+        else:
+            utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with non-const num channels")
+            c_sq = GraphBuilder(ctx).make_squeeze({'data': c, 'axes': [0]})
+            c_range = ctx.make_node("Range", [const_zero, c_sq, const_one]).output[0]
+        xyc = ctx.make_node("Add", [xy_scale_c, c_range]).output[0]
+        # A statically known batch of 1 only ever has batch offset 0, so the batch term is skipped
+        single_batch = input_shape_guess is not None and input_shape_guess[0] == 1
+        if node.get_attr_value('include_batch_in_index', False) and not single_batch:
+            utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with include_batch_in_index")
+            n_sq = GraphBuilder(ctx).make_squeeze({'data': n, 'axes': [0]})
+            n_range = ctx.make_node("Range", [const_zero, n_sq, const_one]).output[0]
+            n_range_unsq = GraphBuilder(ctx).make_unsqueeze({'data': n_range, 'axes': [1, 2, 3]})
+            n_range_scale = ctx.make_node("Mul", [n_range_unsq, chw]).output[0]
+            result = ctx.make_node("Add", [xyc, n_range_scale]).output[0]
+        else:
+            result = xyc
+        ctx.replace_all_inputs(node.output[1], result, ops=consumers)
+
         add_padding(ctx, node, kernel_shape, strides)
         conv_convert_inputs(ctx, node, with_kernel=False, input_indices=[0], output_indices=[0, 1])