Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add cubic interpolation to Resize Operator #22

Merged
merged 5 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/hidet/graph/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,14 +572,14 @@ def run(self, inputs: List[Tensor]) -> List[Tensor]:
if roi is not None:
roi = self.tensor2list(roi)
target_size = None
if scales is not None:
if scales is not None and scales.num_elements > 0:
scales = self.tensor2list(scales)
assert len(x.shape) == len(scales)
target_size = [int(a * b) for a, b in zip(x.shape, scales)]
if sizes is not None:
elif sizes is not None and sizes.num_elements > 0:
sizes = self.tensor2list(sizes)
target_size = [int(v) for v in sizes]
if target_size is None:
else:
raise ValueError('Resize operator in onnx must give either scales or sizes.')
if len(x.shape) == 4:
if not (target_size[0] == x.shape[0] and target_size[1] == x.shape[1]):
Expand Down
56 changes: 49 additions & 7 deletions python/hidet/graph/ops/definitions/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, List

from hidet.ir.expr import Expr, if_then_else, convert, cast, And
from hidet.ir.expr import Expr, if_then_else, convert, cast, And, Or
from hidet.ir import primitives as prim
from .utils import Task, Operator, Tensor, TensorNode, compute, input_like

Expand All @@ -14,7 +14,7 @@ def get_origin_index(x: Expr, image_width: int, target_width: int, coordinate_tr
'half_pixel': lambda x: (x + 0.5) * scale - 0.5,
'align_corners': lambda x: x * ((image_width - 1) / (target_width - 1)),
'asymmetric': lambda x: x * scale,
'pytorch_half_pixel': lambda x: (x + 0.5) * scale if target_width > 1 else convert(0.0),
'pytorch_half_pixel': lambda x: (x + 0.5) * scale - 0.5 if target_width > 1 else convert(0.0),
'tf_half_pixel_for_nn': lambda x: (x + 0.5) * scale,
}
if coordinate_transformation_mode not in func_map:
Expand All @@ -28,7 +28,7 @@ def get_origin_index(x: Expr, image_width: int, target_width: int, coordinate_tr

def get_closest_index(x: Expr, rounding_method: str) -> Expr:
func_map = {
'rounding_method': lambda x: cast(prim.round(x), 'int32'),
'round': lambda x: cast(prim.round(x), 'int32'),
'round_prefer_floor': lambda x: cast(prim.ceil(x - 0.5), 'int32'),
'round_prefer_ceil': lambda x: cast(prim.floor(x + 0.5), 'int32'),
'floor': lambda x: cast(prim.floor(x + 1e-5), 'int32'), # add epsilon (1e-5) to prevent gpu rounding error
Expand All @@ -41,15 +41,30 @@ def get_closest_index(x: Expr, rounding_method: str) -> Expr:

def get_2d_pixel(data: TensorNode, n, c, h, w) -> Expr:
height, width = data.const_shape()[2:]
h = prim.max(0, prim.min(height, h))
w = prim.max(0, prim.min(width, w))
h = prim.max(0, prim.min(height - 1, h))
w = prim.max(0, prim.min(width - 1, w))
return data[n, c, h, w]


def linear_interpolate(a, b, ratio):
return a * (1.0 - ratio) + b * ratio


def get_cubic_weights(s, a):
# See equations (4)-(6) in https://ieeexplore.ieee.org/document/1163711
s2 = s * s
s3 = s * s * s
w1 = a * (s3 - 2 * s2 + s)
w2 = (a + 2) * s3 - (3 + a) * s2 + 1
w3 = -(a + 2) * s3 + (3 + 2 * a) * s2 - a * s
w4 = -a * s3 + a * s2
return [w1, w2, w3, w4]


def cubic_interpolate(inputs, weights):
return sum(inputs_i * weights_i for inputs_i, weights_i in zip(inputs, weights))


def resize2d_nchw_compute(
data: TensorNode,
size: List[int],
Expand All @@ -63,6 +78,8 @@ def resize2d_nchw_compute(
): # pylint: disable=unused-argument
image_size = data.const_shape()[2:]
target_size = size
image_height = image_size[0]
image_width = image_size[1]

def fmap(n, c, h, w):
h = get_origin_index(h, image_size[0], target_size[0], coordinate_transformation_mode)
Expand All @@ -81,7 +98,32 @@ def fmap(n, c, h, w):
bottom = linear_interpolate(*pixels[1], w_ratio)
value = linear_interpolate(top, bottom, h_ratio)
elif method == 'cubic':
raise NotImplementedError(method)
h_int = cast(prim.floor(h), 'int32')
w_int = cast(prim.floor(w), 'int32')
h_ratio = h - prim.floor(h)
w_ratio = w - prim.floor(w)
pixels = [[get_2d_pixel(data, n, c, h_int + i - 1, w_int + j - 1) for j in range(4)] for i in range(4)]

weight_w = get_cubic_weights(w_ratio, cubic_alpha)
weight_h = get_cubic_weights(h_ratio, cubic_alpha)
if cubic_exclude:
for i in range(4):
weight_w[i] = if_then_else(
Or.join((w_int - 1 + i) < 0, (w_int + i) > image_width), 0.0, weight_w[i]
)
weight_h[i] = if_then_else(
Or.join((h_int - 1 + i) < 0, (h_int + i) > image_height), 0.0, weight_h[i]
)
sum_weight_w = sum(weight_w)
sum_weight_h = sum(weight_h)
weight_w = [w / sum_weight_w for w in weight_w]
weight_h = [h / sum_weight_h for h in weight_h]
col0 = cubic_interpolate(pixels[0], weight_w)
col1 = cubic_interpolate(pixels[1], weight_w)
col2 = cubic_interpolate(pixels[2], weight_w)
col3 = cubic_interpolate(pixels[3], weight_w)
value = cubic_interpolate([col0, col1, col2, col3], weight_h)

else:
raise ValueError(
'Unsupported scaling method: {}, candidates: {}'.format(method, ['nearest', 'linear', 'cubic'])
Expand Down Expand Up @@ -134,7 +176,7 @@ class Resize2dOp(Operator):
'tf_half_pixel_for_nn',
'tf_crop_and_resize',
]
supported_rounding_methods = ['round', 'floor', 'ceil']
supported_rounding_methods = ['round', 'round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil']

def __init__(
self,
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __eq__(self, other):
def __hash__(self):
return id(self)

def __gt__(self, other):
return LessThan(other, self)

def __ge__(self, other):
return LessEqual(other, self)

Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/functors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def visit_LessThan(self, e: LessThan):
self.visit(e.a)
self.visit(e.b)

def visit_LessEqual(self, e: LessThan):
def visit_LessEqual(self, e: LessEqual):
self.visit(e.a)
self.visit(e.b)

Expand Down
109 changes: 109 additions & 0 deletions tests/graph/operators/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import List

import numpy as np
import torch
import torchvision as tv
import pytest

import hidet
from hidet import ops
from hidet.testing import check_binary
from hidet.graph.tensor import array
from hidet.utils.ort_utils import create_ort_session, ort_inference


class TorchResizeModel(torch.nn.Module):
def __init__(self, size, method):
super(TorchResizeModel, self).__init__()
self.transform = tv.transforms.Resize(size, method)

def forward(self, x):
x = self.transform(x)
return x


def ort_resize2d(data: np.ndarray, size: List[int], method: str):
method_map = {
'nearest': tv.transforms.InterpolationMode.NEAREST,
'linear': tv.transforms.InterpolationMode.BILINEAR,
'cubic': tv.transforms.InterpolationMode.BICUBIC,
}
if method not in method_map:
raise NotImplementedError(method)

torch_model = TorchResizeModel(size, method_map[method])
torch_input = torch.from_numpy(data).cuda()
torch.onnx.export(torch_model, torch_input, "torch_resize.onnx")
ort_session = create_ort_session("torch_resize.onnx")
ort_inputs = {'img': hidet.from_torch(torch_input)}
ort_outputs = ort_inference(ort_session, ort_inputs)
ort_output = next(iter(ort_outputs.values()))
return ort_output.numpy()


def torch_resize2d(data: np.ndarray, size: List[int], method: str):
method_map = {
'nearest': tv.transforms.InterpolationMode.NEAREST,
'linear': tv.transforms.InterpolationMode.BILINEAR,
'cubic': tv.transforms.InterpolationMode.BICUBIC,
}
if method not in method_map:
raise NotImplementedError(method)
data_torch = torch.from_numpy(data)
transform = tv.transforms.Resize(size, method_map[method])
output = transform(data_torch).numpy()
return output


# In Pytorch, 'linear' and 'cubic' modes use 'half_pixel' coordinate transformation mode,
# while 'nearest' mode uses 'asymmetric' and 'floor'
@pytest.mark.parametrize(
"n, c, h, w, size, method, coordinate_transformation_mode, rounding_method, roi, cubic_alpha, cubic_exclude, extrapolation_value",
[
[1, 1, 32, 32, [50, 60], 'nearest', 'asymmetric', 'floor', [], -0.75, 0, 0.0], # nearest upsample
[1, 1, 32, 32, [20, 15], 'nearest', 'asymmetric', 'floor', [], -0.75, 0, 0.0], # nearest downsample
[1, 3, 32, 32, [50, 60], 'linear', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # linear upsample
[1, 3, 32, 32, [20, 15], 'linear', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # linear downsample
[1, 3, 32, 32, [50, 60], 'cubic', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # cubic upsample
[1, 3, 32, 32, [20, 15], 'cubic', 'half_pixel', 'floor', [], -0.75, 0, 0.0], # cubic downsample
],
)
def test_resize2d(
n,
c,
h,
w,
size,
method,
coordinate_transformation_mode,
rounding_method,
roi,
cubic_alpha,
cubic_exclude,
extrapolation_value,
):
data_shape = [n, c, h, w]
dtype = 'float32'
data = np.array(np.random.randn(*data_shape)).astype(dtype)
torch_result = torch_resize2d(data, size, method)

hidet_result_cuda = (
ops.resize2d(
array(data).to(device='cuda'),
size,
method,
coordinate_transformation_mode,
rounding_method,
roi,
cubic_alpha,
cubic_exclude,
extrapolation_value,
)
.cpu()
.numpy()
)
np.testing.assert_allclose(actual=hidet_result_cuda, desired=torch_result, atol=2e-5, rtol=2e-5)


if __name__ == '__main__':
pytest.main([__file__])