From 769f6562f89a82ef9b80b2e69319c74450c77030 Mon Sep 17 00:00:00 2001 From: Chin Huang Date: Fri, 27 Mar 2020 23:29:32 +0000 Subject: [PATCH 1/2] Add negative paddings support for pad This PR is to add negative paddings support for pad operator according to ONNX spec, negative pads mean to remove elements. It also should address the issue https://github.com/onnx/onnx-tensorflow/issues/584 The current assumption is that paddings can either positive or negative, not mix of them, since ONNX spec does not define such scenario and behavior. --- onnx_tf/handlers/backend/pad.py | 118 +++++++++++++++++++++++++------- test/backend/test_node.py | 15 ++++ 2 files changed, 108 insertions(+), 25 deletions(-) diff --git a/onnx_tf/handlers/backend/pad.py b/onnx_tf/handlers/backend/pad.py index e9e17a1b0..14a737be3 100644 --- a/onnx_tf/handlers/backend/pad.py +++ b/onnx_tf/handlers/backend/pad.py @@ -17,40 +17,108 @@ def _common(cls, node, **kwargs): num_dim = len(tensor_dict[node.inputs[0]].get_shape()) mode = node.attrs.pop("mode", "constant") + def check_positive(pads): + p = tf.greater_equal(pads, tf.zeros((1), dtype=pads.dtype)) + r = tf.reduce_all(p) + return r + + def process_neg_pads(x, paddings): + # process negative paddings differently since TF.pad + # doesn't support negative paddings + + i_shape = tf.shape(x) + cond_less = lambda i1, i2, o1: tf.less(i1, i2) + body_concat = lambda i1, i2, o1: [ + i1 + 1, i2, tf.concat([o1, [i1]], axis=0) + ] + cond_neg_pads = lambda i1, i2, i3, o1: tf.less(i1, i2) + + def _loop_neg_pads(i, i_x, p, result): + # process one dimension at a time + + i_min = tf.negative(tf.gather(p, i * 2)) + i_max = i_shape[i] + tf.gather(p, i * 2 + 1) + t = tf.constant([0]) + _, _, r = tf.while_loop(cond_less, + body_concat, [i_min, i_max, t], + shape_invariants=[ + i_min.get_shape(), + i_max.get_shape(), + tf.TensorShape([None]) + ], + parallel_iterations=1) + gather_indices = tf.gather(r, tf.range(1, tf.size(r))) + result = tf.gather(result, gather_indices) + + # prepare for the next loop + i_min = tf.constant(0) + i_max = i_x + _, _, r = tf.while_loop(cond_less, + body_concat, [i_min, i_max, t], + shape_invariants=[ + i_min.get_shape(), + i_max.get_shape(), + tf.TensorShape([None]) + ], + parallel_iterations=1) + transpose_indices = tf.gather(r, tf.range(1, tf.size(r))) + transpose_indices = tf.roll(transpose_indices, shift=-1, axis=0) + result = tf.transpose(result, transpose_indices) + return i + 1, i_x, p, result + + # tf requires int32 paddings + paddings = tf.cast(paddings, dtype=tf.int32) + i = tf.constant(0) + i_rank = tf.rank(x) + _, _, _, result = tf.while_loop(cond_neg_pads, + _loop_neg_pads, [i, i_rank, paddings, x], + shape_invariants=[ + i.get_shape(), + i_rank.get_shape(), + paddings.get_shape(), + tf.TensorShape(None) + ], + parallel_iterations=1) + return [result] + + def process_pos_pads(x, paddings): + + def _symmetric_pad(i, x): + paddings_i = tf.map_fn(lambda e: tf.where(i < e, 1, 0), paddings) + paddings_i = tf.reshape(paddings_i, [num_dim, 2]) + x = tf.pad(x, paddings_i, 'SYMMETRIC') + return i + 1, x + + # tf requires int32 paddings + paddings = tf.cast(tf.transpose(tf.reshape(paddings, [2, num_dim])), + dtype=tf.int32) + + if mode.lower() == "edge": + paddings = tf.reshape(paddings, [-1]) + max_i = tf.reduce_max(paddings) + _, x = tf.while_loop( + lambda i, x: tf.less(i, max_i), _symmetric_pad, [0, x], + [tf.TensorShape([]), tf.TensorShape(None)]) + return [x] + + return [ + cls.make_tensor_from_onnx_node( + node, inputs=[x, paddings, mode, constant_values], **kwargs) + ] + if cls.SINCE_VERSION < 11: # for opset 1 and opset 2 paddings = node.attrs.pop("pads", None) - # tf requires int32 paddings - paddings = tf.constant( - np.transpose( - np.array(paddings).reshape([2, num_dim]).astype(np.int32))) constant_values = node.attrs.pop("value", 0.) else: # for opset 11 paddings = tensor_dict[node.inputs[1]] - # tf requires int32 paddings - paddings = tf.cast( - tf.transpose(tf.reshape(paddings, [2, num_dim])), dtype=tf.int32) constant_values = tensor_dict[node.inputs[2]] if len( node.inputs) == 3 else 0 - def _symmetric_pad(i, x): - paddings_i = tf.map_fn(lambda e: tf.where(i < e, 1, 0), paddings) - paddings_i = tf.reshape(paddings_i, [num_dim, 2]) - x = tf.pad(x, paddings_i, 'SYMMETRIC') - return i + 1, x - - if mode.lower() == "edge": - paddings = tf.reshape(paddings, [-1]) - max_i = tf.reduce_max(paddings) - _, x = tf.while_loop( - lambda i, x: tf.less(i, max_i), _symmetric_pad, [0, x], - [tf.TensorShape([]), tf.TensorShape(None)]) - return [x] - - return [ - cls.make_tensor_from_onnx_node( - node, inputs=[x, paddings, mode, constant_values], **kwargs) - ] + cond = tf.cond(check_positive(paddings), + lambda: process_pos_pads(x, paddings), + lambda: process_neg_pads(x, paddings)) + return cond @classmethod def version_1(cls, node, **kwargs): diff --git a/test/backend/test_node.py b/test/backend/test_node.py index 9ce618683..910debbee 100644 --- a/test/backend/test_node.py +++ b/test/backend/test_node.py @@ -1947,6 +1947,21 @@ def test_pad(self): output = run_node(node_def, [x, pads]) y = np.pad(x, ((1, 1), (1, 1)), mode) np.testing.assert_almost_equal(output["Y"], y) + # negative pads + node_def = helper.make_node("Pad", ["X", "pads"], ["Y"], mode="constant") + pads = np.array([-2, -2, -2, -2], dtype=np.int64) + x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).reshape((3, 4)) + y = x + x = np.pad(x, ((2, 2), (2, 2)), 'constant') + output = run_node(node_def, [x, pads]) + np.testing.assert_almost_equal(output["Y"], y) + # negative pads with 3 dimensions + node_def = helper.make_node("Pad", ["X", "pads"], ["Y"], mode="constant") + pads = np.array([-1, 0, 0, -1, 0, 0], dtype=np.int64) + x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).reshape((2, 3, 2)) + y = np.array([7, 8, 9, 10]).reshape((1, 2, 2)) + output = run_node(node_def, [x, pads]) + np.testing.assert_almost_equal(output["Y"], y) def test_qlinearconv(self): if legacy_opset_pre_ver(10): From 3abe719aacfa26a367fa5a7273c02595da6cc147 Mon Sep 17 00:00:00 2001 From: Chin Huang Date: Sat, 28 Mar 2020 00:05:32 +0000 Subject: [PATCH 2/2] remove unused import --- onnx_tf/handlers/backend/pad.py | 70 +++++++-------------------------- test/backend/test_node.py | 3 +- 2 files changed, 16 insertions(+), 57 deletions(-) diff --git a/onnx_tf/handlers/backend/pad.py b/onnx_tf/handlers/backend/pad.py index 14a737be3..0e000bcc2 100644 --- a/onnx_tf/handlers/backend/pad.py +++ b/onnx_tf/handlers/backend/pad.py @@ -1,4 +1,3 @@ -import numpy as np import tensorflow as tf from onnx_tf.handlers.backend_handler import BackendHandler @@ -23,62 +22,17 @@ def check_positive(pads): return r def process_neg_pads(x, paddings): - # process negative paddings differently since TF.pad + # Process negative paddings differently since tf.pad # doesn't support negative paddings - - i_shape = tf.shape(x) - cond_less = lambda i1, i2, o1: tf.less(i1, i2) - body_concat = lambda i1, i2, o1: [ - i1 + 1, i2, tf.concat([o1, [i1]], axis=0) - ] - cond_neg_pads = lambda i1, i2, i3, o1: tf.less(i1, i2) - - def _loop_neg_pads(i, i_x, p, result): - # process one dimension at a time - - i_min = tf.negative(tf.gather(p, i * 2)) - i_max = i_shape[i] + tf.gather(p, i * 2 + 1) - t = tf.constant([0]) - _, _, r = tf.while_loop(cond_less, - body_concat, [i_min, i_max, t], - shape_invariants=[ - i_min.get_shape(), - i_max.get_shape(), - tf.TensorShape([None]) - ], - parallel_iterations=1) - gather_indices = tf.gather(r, tf.range(1, tf.size(r))) - result = tf.gather(result, gather_indices) - - # prepare for the next loop - i_min = tf.constant(0) - i_max = i_x - _, _, r = tf.while_loop(cond_less, - body_concat, [i_min, i_max, t], - shape_invariants=[ - i_min.get_shape(), - i_max.get_shape(), - tf.TensorShape([None]) - ], - parallel_iterations=1) - transpose_indices = tf.gather(r, tf.range(1, tf.size(r))) - transpose_indices = tf.roll(transpose_indices, shift=-1, axis=0) - result = tf.transpose(result, transpose_indices) - return i + 1, i_x, p, result - - # tf requires int32 paddings - paddings = tf.cast(paddings, dtype=tf.int32) - i = tf.constant(0) - i_rank = tf.rank(x) - _, _, _, result = tf.while_loop(cond_neg_pads, - _loop_neg_pads, [i, i_rank, paddings, x], - shape_invariants=[ - i.get_shape(), - i_rank.get_shape(), - paddings.get_shape(), - tf.TensorShape(None) - ], - parallel_iterations=1) + # The ONNX logic is similar to tf.slice. So we just + # need to compute the begins and sizes for slice op + + i_shape = tf.shape(x, out_type=paddings.dtype) + i_rank = tf.cast(tf.rank(x), paddings.dtype) + begins = tf.negative(tf.gather(paddings, tf.range(i_rank))) + ends = i_shape + tf.gather(paddings, tf.range(i_rank, i_rank*2)) + sizes = ends - begins + result=tf.slice(x, begins, sizes) return [result] def process_pos_pads(x, paddings): @@ -94,6 +48,10 @@ def _symmetric_pad(i, x): dtype=tf.int32) if mode.lower() == "edge": + # Tensorflow doesn't support edge mode so we need to implement the + # np.pad(x, paddings, mode="edge") logic using Tensorflow ops. A + # while loop is used to go through the tf.pad 'SYMMETRIC' mode to pad + # one value at a time for both sides and all dimensions. paddings = tf.reshape(paddings, [-1]) max_i = tf.reduce_max(paddings) _, x = tf.while_loop( diff --git a/test/backend/test_node.py b/test/backend/test_node.py index 910debbee..55b8e0b19 100644 --- a/test/backend/test_node.py +++ b/test/backend/test_node.py @@ -1955,9 +1955,10 @@ def test_pad(self): x = np.pad(x, ((2, 2), (2, 2)), 'constant') output = run_node(node_def, [x, pads]) np.testing.assert_almost_equal(output["Y"], y) + # negative pads with 3 dimensions node_def = helper.make_node("Pad", ["X", "pads"], ["Y"], mode="constant") - pads = np.array([-1, 0, 0, -1, 0, 0], dtype=np.int64) + pads = np.array([-1, 0, 0, 0, -1, 0], dtype=np.int64) x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).reshape((2, 3, 2)) y = np.array([7, 8, 9, 10]).reshape((1, 2, 2)) output = run_node(node_def, [x, pads])