Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,6 +2437,17 @@ def func(x, x_new_size_):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})

@check_tf_min_version("2.0", "Results are slightly different in tf1")
@check_opset_min_version(11, "resize bicubic")
def test_resize_bicubic(self):
x_shape = [1, 15, 20, 2]
new_size_val = np.array([30, 40], dtype=np.int32)
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
def func(x, new_size):
y = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BICUBIC)
return tf.identity(y, name=_TFOUTPUT)
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: new_size_val}, rtol=1e-6, atol=1e-5)

@check_opset_min_version(10, "resize scale can less than 1")
def test_resize_nearest_neighbor2(self):
x_shape = [1, 300, 20, 2]
Expand Down
25 changes: 20 additions & 5 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,12 @@ def version_13(cls, ctx, node, **kwargs):
cls.any_version_after11(13, ctx, node, **kwargs)


@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor", "ResizeBicubic"])
class Resize:
@classmethod
def version_7(cls, ctx, node, **kwargs):
utils.make_sure(node.type != "ResizeBicubic", "Opset 11 is required for bicubic interpolation for node %s",
node.name)
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
node.type = "Upsample"
shape = ctx.get_shape(node.input[0])
Expand Down Expand Up @@ -1009,7 +1011,16 @@ def version_10(cls, ctx, node, **kwargs):

@classmethod
def version_11(cls, ctx, node, **kwargs):
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
cubic_coeff_a = None
exclude_outside = False
if node.type == "ResizeBilinear":
mode = "linear"
elif node.type == "ResizeBicubic":
mode = "cubic"
cubic_coeff_a = -0.5
exclude_outside = True
else:
mode = "nearest"
roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array([0]).astype(np.int64))
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(np.int64))
Expand All @@ -1035,9 +1046,11 @@ def version_11(cls, ctx, node, **kwargs):
nearest_mode = "round_prefer_ceil"
else:
transformation_mode = "half_pixel"
resize = ctx.make_node("Resize", resize_inputs,
attr={"mode": mode, "nearest_mode": nearest_mode,
"coordinate_transformation_mode": transformation_mode})
attr = {"mode": mode, "nearest_mode": nearest_mode, "coordinate_transformation_mode": transformation_mode,
"exclude_outside": exclude_outside}
if cubic_coeff_a is not None:
attr["cubic_coeff_a"] = cubic_coeff_a
resize = ctx.make_node("Resize", resize_inputs, attr=attr)
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
Expand All @@ -1050,6 +1063,8 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
# wants the input to be NHWC - adjust target_shape to this.
utils.make_sure(node.type != "ResizeBicubic", "Opset 11 is required for bicubic interpolation for node %s",
node.name)
mode = "linear" if node.type == "ResizeBilinear" else "nearest"

# because onnxruntime only supports to scale the last two dims so transpose is inserted
Expand Down