Skip to content

Commit

Permalink
[Operator] Add cubic interpolation to Resize Operator (#22)
Browse files Browse the repository at this point in the history
* WIP: Resize2d cubic; Add pytorch test for Resize (test failed)

* Add draft ver. of Resize2d Cubic

* Fix get pixel bug in resize2d

* Add Greater Than expression

* Fix typo and Resize test case
  • Loading branch information
hjjq committed Nov 16, 2022
1 parent 93fc859 commit 1c6909c
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 11 deletions.
6 changes: 3 additions & 3 deletions python/hidet/graph/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,14 +572,14 @@ def run(self, inputs: List[Tensor]) -> List[Tensor]:
if roi is not None:
roi = self.tensor2list(roi)
target_size = None
if scales is not None:
if scales is not None and scales.num_elements > 0:
scales = self.tensor2list(scales)
assert len(x.shape) == len(scales)
target_size = [int(a * b) for a, b in zip(x.shape, scales)]
if sizes is not None:
elif sizes is not None and sizes.num_elements > 0:
sizes = self.tensor2list(sizes)
target_size = [int(v) for v in sizes]
if target_size is None:
else:
raise ValueError('Resize operator in onnx must give either scales or sizes.')
if len(x.shape) == 4:
if not (target_size[0] == x.shape[0] and target_size[1] == x.shape[1]):
Expand Down
56 changes: 49 additions & 7 deletions python/hidet/graph/ops/definitions/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, List

from hidet.ir.expr import Expr, if_then_else, convert, cast, And
from hidet.ir.expr import Expr, if_then_else, convert, cast, And, Or
from hidet.ir import primitives as prim
from .utils import Task, Operator, Tensor, TensorNode, compute, input_like

Expand All @@ -14,7 +14,7 @@ def get_origin_index(x: Expr, image_width: int, target_width: int, coordinate_tr
'half_pixel': lambda x: (x + 0.5) * scale - 0.5,
'align_corners': lambda x: x * ((image_width - 1) / (target_width - 1)),
'asymmetric': lambda x: x * scale,
'pytorch_half_pixel': lambda x: (x + 0.5) * scale if target_width > 1 else convert(0.0),
'pytorch_half_pixel': lambda x: (x + 0.5) * scale - 0.5 if target_width > 1 else convert(0.0),
'tf_half_pixel_for_nn': lambda x: (x + 0.5) * scale,
}
if coordinate_transformation_mode not in func_map:
Expand All @@ -28,7 +28,7 @@ def get_origin_index(x: Expr, image_width: int, target_width: int, coordinate_tr

def get_closest_index(x: Expr, rounding_method: str) -> Expr:
func_map = {
'rounding_method': lambda x: cast(prim.round(x), 'int32'),
'round': lambda x: cast(prim.round(x), 'int32'),
'round_prefer_floor': lambda x: cast(prim.ceil(x - 0.5), 'int32'),
'round_prefer_ceil': lambda x: cast(prim.floor(x + 0.5), 'int32'),
'floor': lambda x: cast(prim.floor(x + 1e-5), 'int32'), # add epsilon (1e-5) to prevent gpu rounding error
Expand All @@ -41,15 +41,30 @@ def get_closest_index(x: Expr, rounding_method: str) -> Expr:

def get_2d_pixel(data: TensorNode, n, c, h, w) -> Expr:
height, width = data.const_shape()[2:]
h = prim.max(0, prim.min(height, h))
w = prim.max(0, prim.min(width, w))
h = prim.max(0, prim.min(height - 1, h))
w = prim.max(0, prim.min(width - 1, w))
return data[n, c, h, w]


def linear_interpolate(a, b, ratio):
return a * (1.0 - ratio) + b * ratio


def get_cubic_weights(s, a):
# See equations (4)-(6) in https://ieeexplore.ieee.org/document/1163711
s2 = s * s
s3 = s * s * s
w1 = a * (s3 - 2 * s2 + s)
w2 = (a + 2) * s3 - (3 + a) * s2 + 1
w3 = -(a + 2) * s3 + (3 + 2 * a) * s2 - a * s
w4 = -a * s3 + a * s2
return [w1, w2, w3, w4]


def cubic_interpolate(inputs, weights):
return sum(inputs_i * weights_i for inputs_i, weights_i in zip(inputs, weights))


def resize2d_nchw_compute(
data: TensorNode,
size: List[int],
Expand All @@ -63,6 +78,8 @@ def resize2d_nchw_compute(
): # pylint: disable=unused-argument
image_size = data.const_shape()[2:]
target_size = size
image_height = image_size[0]
image_width = image_size[1]

def fmap(n, c, h, w):
h = get_origin_index(h, image_size[0], target_size[0], coordinate_transformation_mode)
Expand All @@ -81,7 +98,32 @@ def fmap(n, c, h, w):
bottom = linear_interpolate(*pixels[1], w_ratio)
value = linear_interpolate(top, bottom, h_ratio)
elif method == 'cubic':
raise NotImplementedError(method)
h_int = cast(prim.floor(h), 'int32')
w_int = cast(prim.floor(w), 'int32')
h_ratio = h - prim.floor(h)
w_ratio = w - prim.floor(w)
pixels = [[get_2d_pixel(data, n, c, h_int + i - 1, w_int + j - 1) for j in range(4)] for i in range(4)]

weight_w = get_cubic_weights(w_ratio, cubic_alpha)
weight_h = get_cubic_weights(h_ratio, cubic_alpha)
if cubic_exclude:
for i in range(4):
weight_w[i] = if_then_else(
Or.join((w_int - 1 + i) < 0, (w_int + i) > image_width), 0.0, weight_w[i]
)
weight_h[i] = if_then_else(
Or.join((h_int - 1 + i) < 0, (h_int + i) > image_height), 0.0, weight_h[i]
)
sum_weight_w = sum(weight_w)
sum_weight_h = sum(weight_h)
weight_w = [w / sum_weight_w for w in weight_w]
weight_h = [h / sum_weight_h for h in weight_h]
col0 = cubic_interpolate(pixels[0], weight_w)
col1 = cubic_interpolate(pixels[1], weight_w)
col2 = cubic_interpolate(pixels[2], weight_w)
col3 = cubic_interpolate(pixels[3], weight_w)
value = cubic_interpolate([col0, col1, col2, col3], weight_h)

else:
raise ValueError(
'Unsupported scaling method: {}, candidates: {}'.format(method, ['nearest', 'linear', 'cubic'])
Expand Down Expand Up @@ -134,7 +176,7 @@ class Resize2dOp(Operator):
'tf_half_pixel_for_nn',
'tf_crop_and_resize',
]
supported_rounding_methods = ['round', 'floor', 'ceil']
supported_rounding_methods = ['round', 'round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil']

def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __eq__(self, other):
def __hash__(self):
return id(self)

def __gt__(self, other):
return LessThan(other, self)

def __ge__(self, other):
return LessEqual(other, self)

Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/functors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def visit_LessThan(self, e: LessThan):
self.visit(e.a)
self.visit(e.b)

def visit_LessEqual(self, e: LessThan):
def visit_LessEqual(self, e: LessEqual):
self.visit(e.a)
self.visit(e.b)

Expand Down
109 changes: 109 additions & 0 deletions tests/graph/operators/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import List

import numpy as np
import torch
import torchvision as tv
import pytest

import hidet
from hidet import ops
from hidet.testing import check_binary
from hidet.graph.tensor import array
from hidet.utils.ort_utils import create_ort_session, ort_inference


class TorchResizeModel(torch.nn.Module):
def __init__(self, size, method):
super(TorchResizeModel, self).__init__()
self.transform = tv.transforms.Resize(size, method)

def forward(self, x):
x = self.transform(x)
return x


def ort_resize2d(data: np.ndarray, size: List[int], method: str):
method_map = {
'nearest': tv.transforms.InterpolationMode.NEAREST,
'linear': tv.transforms.InterpolationMode.BILINEAR,
'cubic': tv.transforms.InterpolationMode.BICUBIC,
}
if method not in method_map:
raise NotImplementedError(method)

torch_model = TorchResizeModel(size, method_map[method])
torch_input = torch.from_numpy(data).cuda()
torch.onnx.export(torch_model, torch_input, "torch_resize.onnx")
ort_session = create_ort_session("torch_resize.onnx")
ort_inputs = {'img': hidet.from_torch(torch_input)}
ort_outputs = ort_inference(ort_session, ort_inputs)
ort_output = next(iter(ort_outputs.values()))
return ort_output.numpy()


def torch_resize2d(data: np.ndarray, size: List[int], method: str):
method_map = {
'nearest': tv.transforms.InterpolationMode.NEAREST,
'linear': tv.transforms.InterpolationMode.BILINEAR,
'cubic': tv.transforms.InterpolationMode.BICUBIC,
}
if method not in method_map:
raise NotImplementedError(method)
data_torch = torch.from_numpy(data)
transform = tv.transforms.Resize(size, method_map[method])
output = transform(data_torch).numpy()
return output


# In Pytorch, 'linear' and 'cubic' modes use 'half_pixel' coordinate transformation mode,
# while 'nearest' mode uses 'asymmetric' and 'floor'
@pytest.mark.parametrize(
"n, c, h, w, size, method, coordinate_transformation_mode, rounding_method, roi, cubic_alpha, cubic_exclude, extrapolation_value",
[
[1, 1, 32, 32, [50, 60], 'nearest', 'asymmetric', 'floor', [], -0.75, 0, 0.0], # nearest upsample
[1, 1, 32, 32, [20, 15], 'nearest', 'asymmetric', 'floor', [], -0.75, 0, 0.0], # nearest downsample
[1, 3, 32, 32, [50, 60], 'linear', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # linear upsample
[1, 3, 32, 32, [20, 15], 'linear', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # linear downsample
[1, 3, 32, 32, [50, 60], 'cubic', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # cubic upsample
[1, 3, 32, 32, [20, 15], 'cubic', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # cubic downsample
],
)
def test_resize2d(
n,
c,
h,
w,
size,
method,
coordinate_transformation_mode,
rounding_method,
roi,
cubic_alpha,
cubic_exclude,
extrapolation_value,
):
data_shape = [n, c, h, w]
dtype = 'float32'
data = np.array(np.random.randn(*data_shape)).astype(dtype)
torch_result = torch_resize2d(data, size, method)

hidet_result_cuda = (
ops.resize2d(
array(data).to(device='cuda'),
size,
method,
coordinate_transformation_mode,
rounding_method,
roi,
cubic_alpha,
cubic_exclude,
extrapolation_value,
)
.cpu()
.numpy()
)
np.testing.assert_allclose(actual=hidet_result_cuda, desired=torch_result, atol=2e-5, rtol=2e-5)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 1c6909c

Please sign in to comment.