Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Enable variable side resizing in kornia.resize #628

Merged
merged 9 commits into from Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 36 additions & 23 deletions kornia/geometry/transform/affwarp.py
Expand Up @@ -338,8 +338,24 @@ def shear(tensor: torch.Tensor, shear: torch.Tensor, align_corners: bool = False
return affine(tensor, shear_matrix[..., :2, :3], align_corners=align_corners)


def _edge_to_image_size(
edge_size: int, aspect_ratio: float, edge: str = "short"
) -> Tuple[int, int]:
if edge not in ("short", "long", "vert", "horz"):
raise ValueError(f"edge can be one of 'short', 'long', 'vert', and 'horz'. Got '{edge}'")
if edge == "vert":
return edge_size, int(edge_size * aspect_ratio)
elif edge == "horz":
return int(edge_size / aspect_ratio), edge_size
elif (edge == "short") ^ (aspect_ratio < 1.0):
return edge_size, int(edge_size * aspect_ratio)
else:
return int(edge_size / aspect_ratio), edge_size


def resize(input: torch.Tensor, size: Union[int, Tuple[int, int]],
interpolation: str = 'bilinear', align_corners: bool = False) -> torch.Tensor:
interpolation: str = 'bilinear', align_corners: bool = False,
edge: str = "short") -> torch.Tensor:
r"""Resize the input torch.Tensor to the given size.

See :class:`~kornia.Resize` for details.
Expand All @@ -348,49 +364,46 @@ def resize(input: torch.Tensor, size: Union[int, Tuple[int, int]],
raise TypeError("Input tensor type is not a torch.Tensor. Got {}"
.format(type(input)))

new_size: Tuple[int, int]

input_size = h, w = input.shape[-2:]
if isinstance(size, int):
w, h = input.shape[-2:]
if (w <= h and w == size) or (h <= w and h == size):
return input
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
new_size = (ow, oh)
else:
new_size = size
return torch.nn.functional.interpolate(input, size=new_size, mode=interpolation, align_corners=align_corners)
aspect_ratio = w / h
size = _edge_to_image_size(size, aspect_ratio, edge)

if size == input_size:
return input

return torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)


class Resize(nn.Module):
r"""Resize the input torch.Tensor to the given size.

Args:
size (int, tuple(int, int)): Desired output size. If size is a sequence like (h, w),
output size will be matched to this. If size is an int, smaller edge of the image will
be matched to this number. i.e, if height > width, then image will be rescaled
to (size * height / width, size)
output size will be matched to this. If size is an int, smaller edge of the image will
be matched to this number. i.e, if height > width, then image will be rescaled
to (size * height / width, size)
interpolation (str): algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' |
'bicubic' | 'trilinear' | 'area'. Default: 'bilinear'.
'bicubic' | 'trilinear' | 'area'. Default: 'bilinear'.
align_corners(bool): interpolation flag. Default: False. See
https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail
https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for detail
edge (str): Corresponding edge if ``size`` is an integer. Can be one of ``"short"``, ``"long"``, ``"vert"``,
or ``"horz"``. Defaults to ``"short"``.

Returns:
torch.Tensor: The resized tensor.
"""

def __init__(self, size: Union[int, Tuple[int, int]], interpolation: str = 'bilinear',
align_corners: bool = False) -> None:
align_corners: bool = False, edge: str = "short") -> None:
super(Resize, self).__init__()
self.size: Union[int, Tuple[int, int]] = size
self.interpolation: str = interpolation
self.align_corners: bool = align_corners
self.edge = edge

def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
return resize(input, self.size, self.interpolation, align_corners=self.align_corners)
return resize(input, self.size, self.interpolation, align_corners=self.align_corners, edge=self.edge)


class Rotate(nn.Module):
Expand Down
15 changes: 15 additions & 0 deletions test/geometry/transform/test_affine.py
Expand Up @@ -28,6 +28,21 @@ def test_one_param(self, device):
out = kornia.resize(inp, 10)
assert out.shape == (1, 3, 25, 10)

def test_one_param_long(self, device):
inp = torch.rand(1, 3, 5, 2).to(device)
out = kornia.resize(inp, 10, edge="long")
assert out.shape == (1, 3, 10, 4)

def test_one_param_vert(self, device):
inp = torch.rand(1, 3, 5, 2).to(device)
out = kornia.resize(inp, 10, edge="vert")
assert out.shape == (1, 3, 10, 4)

def test_one_param_horz(self, device):
inp = torch.rand(1, 3, 2, 5).to(device)
out = kornia.resize(inp, 10, edge="horz")
assert out.shape == (1, 3, 4, 10)

def test_gradcheck(self, device):
# test parameters
new_size = 4
Expand Down