Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions onnx_tf/handlers/backend/pad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import tensorflow as tf

from onnx_tf.handlers.backend_handler import BackendHandler
Expand All @@ -17,40 +16,67 @@ def _common(cls, node, **kwargs):
num_dim = len(tensor_dict[node.inputs[0]].get_shape())
mode = node.attrs.pop("mode", "constant")

def check_positive(pads):
p = tf.greater_equal(pads, tf.zeros((1), dtype=pads.dtype))
r = tf.reduce_all(p)
return r

def process_neg_pads(x, paddings):
# Process negative paddings differently since tf.pad
# doesn't support negative paddings
# The ONNX logic is similar to tf.slice. So we just
# need to compute the begins and sizes for slice op

i_shape = tf.shape(x, out_type=paddings.dtype)
i_rank = tf.cast(tf.rank(x), paddings.dtype)
begins = tf.negative(tf.gather(paddings, tf.range(i_rank)))
ends = i_shape + tf.gather(paddings, tf.range(i_rank, i_rank*2))
sizes = ends - begins
result=tf.slice(x, begins, sizes)
return [result]

def process_pos_pads(x, paddings):

def _symmetric_pad(i, x):
paddings_i = tf.map_fn(lambda e: tf.where(i < e, 1, 0), paddings)
paddings_i = tf.reshape(paddings_i, [num_dim, 2])
x = tf.pad(x, paddings_i, 'SYMMETRIC')
return i + 1, x

# tf requires int32 paddings
paddings = tf.cast(tf.transpose(tf.reshape(paddings, [2, num_dim])),
dtype=tf.int32)

if mode.lower() == "edge":
# Tensorflow doesn't support edge mode so we need to implement the
# np.pad(x, paddings, mode="edge") logic using Tensorflow ops. A
# while loop is used to go through the tf.pad 'SYMMETRIC' mode to pad
# one value at a time for both sides and all dimensions.
paddings = tf.reshape(paddings, [-1])
max_i = tf.reduce_max(paddings)
_, x = tf.while_loop(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Chin, I either forgot or never understood what this part is doing, could you give me some hints about what's going on with the while loop? Some comments may help with future reviews, too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The while loop code is not changed in this PR, just added with indentation. It does look a bit strange. I will need to investigate what it is doing and whether it works as expected.

Copy link
Contributor Author

@chinhuang007 chinhuang007 Jun 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, now I understand the while loop is to replace the python implementation of edge mode, np.pad(x, pads, mode="edge"). Basically it goes through the tf.pad 'SYMMETRIC' mode to pad one value at a time for both sides and all dimensions wherever needed. I added some comments and hope it helps.

lambda i, x: tf.less(i, max_i), _symmetric_pad, [0, x],
[tf.TensorShape([]), tf.TensorShape(None)])
return [x]

return [
cls.make_tensor_from_onnx_node(
node, inputs=[x, paddings, mode, constant_values], **kwargs)
]

if cls.SINCE_VERSION < 11: # for opset 1 and opset 2
paddings = node.attrs.pop("pads", None)
# tf requires int32 paddings
paddings = tf.constant(
np.transpose(
np.array(paddings).reshape([2, num_dim]).astype(np.int32)))
constant_values = node.attrs.pop("value", 0.)

else: # for opset 11
paddings = tensor_dict[node.inputs[1]]
# tf requires int32 paddings
paddings = tf.cast(
tf.transpose(tf.reshape(paddings, [2, num_dim])), dtype=tf.int32)
constant_values = tensor_dict[node.inputs[2]] if len(
node.inputs) == 3 else 0

def _symmetric_pad(i, x):
paddings_i = tf.map_fn(lambda e: tf.where(i < e, 1, 0), paddings)
paddings_i = tf.reshape(paddings_i, [num_dim, 2])
x = tf.pad(x, paddings_i, 'SYMMETRIC')
return i + 1, x

if mode.lower() == "edge":
paddings = tf.reshape(paddings, [-1])
max_i = tf.reduce_max(paddings)
_, x = tf.while_loop(
lambda i, x: tf.less(i, max_i), _symmetric_pad, [0, x],
[tf.TensorShape([]), tf.TensorShape(None)])
return [x]

return [
cls.make_tensor_from_onnx_node(
node, inputs=[x, paddings, mode, constant_values], **kwargs)
]
cond = tf.cond(check_positive(paddings),
lambda: process_pos_pads(x, paddings),
lambda: process_neg_pads(x, paddings))
return cond

@classmethod
def version_1(cls, node, **kwargs):
Expand Down
16 changes: 16 additions & 0 deletions test/backend/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2674,6 +2674,22 @@ def test_pad(self):
output = run_node(node_def, [x, pads])
y = np.pad(x, ((1, 1), (1, 1)), mode)
np.testing.assert_almost_equal(output["Y"], y)
# negative pads
node_def = helper.make_node("Pad", ["X", "pads"], ["Y"], mode="constant")
pads = np.array([-2, -2, -2, -2], dtype=np.int64)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).reshape((3, 4))
y = x
x = np.pad(x, ((2, 2), (2, 2)), 'constant')
output = run_node(node_def, [x, pads])
np.testing.assert_almost_equal(output["Y"], y)

# negative pads with 3 dimensions
node_def = helper.make_node("Pad", ["X", "pads"], ["Y"], mode="constant")
pads = np.array([-1, 0, 0, 0, -1, 0], dtype=np.int64)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).reshape((2, 3, 2))
y = np.array([7, 8, 9, 10]).reshape((1, 2, 2))
output = run_node(node_def, [x, pads])
np.testing.assert_almost_equal(output["Y"], y)

def test_qlinearconv(self):
if legacy_opset_pre_ver(10):
Expand Down