Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add Support for Images not Div by 16 (Diff. JPEG) #2865

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 29 additions & 6 deletions kornia/enhance/jpeg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import math
from typing import Tuple

import torch
import torch.nn.functional as F

from kornia.color import rgb_to_ycbcr, ycbcr_to_rgb
from kornia.constants import pi
Expand Down Expand Up @@ -344,6 +348,27 @@ def _jpeg_decode(
return rgb_decoded


def _perform_padding(image: Tensor) -> Tuple[Tensor, int, int]:
"""Pads a given image to be dividable by 16.

Args:
image: Image of the shape :math:`(*, 3, H, W)`.

Returns:
image_padded: Padded image of the shape :math:`(*, 3, H_{new}, W_{new})`.
h_pad: Padded pixels along the horizontal axis.
w_pad: Padded pixels along the vertical axis.
"""
# Get spatial dimensions of the image
H, W = image.shape[-2:]
# Compute horizontal and vertical padding
h_pad: int = math.ceil(H / 16) * 16 - H
w_pad: int = math.ceil(W / 16) * 16 - W
# Perform padding (we follow JPEG and pad only the bottom and right side of the image)
image_padded: Tensor = F.pad(image, (0, w_pad, 0, h_pad), "replicate")
return image_padded, h_pad, w_pad


@perform_keep_shape_image
def jpeg_codec_differentiable(
image_rgb: Tensor,
Expand Down Expand Up @@ -437,10 +462,6 @@ def jpeg_codec_differentiable(
KORNIA_CHECK_IS_TENSOR(quantization_table_c)
# Check shape of inputs
KORNIA_CHECK_SHAPE(image_rgb, ["*", "3", "H", "W"])
KORNIA_CHECK(
(image_rgb.shape[-1] % 16 == 0) and (image_rgb.shape[-2] % 16 == 0),
f"image dimension must be divisible by 16. Got the shape {image_rgb.shape}.",
)
KORNIA_CHECK_SHAPE(jpeg_quality, ["B"])
# Add batch dimension to quantization tables if needed
if quantization_table_y.ndim == 2:
Expand All @@ -456,6 +477,8 @@ def jpeg_codec_differentiable(
f"JPEG quality is out of range. Expected range is [0, 100], "
f"got [{jpeg_quality.amin().item()}, {jpeg_quality.amax().item()}]. Consider clipping jpeg_quality.",
)
# Pad the image to a shape divisible by 16
image_rgb, h_pad, w_pad = _perform_padding(image_rgb)
# Get height and width (of the padded image)
H, W = image_rgb.shape[-2:]
# Check matching batch dimensions
Expand Down Expand Up @@ -499,8 +522,8 @@ def jpeg_codec_differentiable(
)
# Clip coded image
image_rgb_jpeg = differentiable_clipping(input=image_rgb_jpeg, min_val=0.0, max_val=255.0)
# Back to original shape
# image_rgb_jpeg = image_rgb_jpeg.view(original_shape)
# Crop the image again to the original shape
image_rgb_jpeg = image_rgb_jpeg[..., : H - h_pad, : W - w_pad]
return image_rgb_jpeg


Expand Down
16 changes: 9 additions & 7 deletions tests/enhance/test_jpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def test_smoke(self, device, dtype) -> None:
assert img_jpeg is not None
assert img_jpeg.shape == img.shape

def test_smoke_not_div_by_16(self, device, dtype) -> None:
    """Smoke test for images whose height and width are not divisible by 16."""
    # Use a non-square shape so height and width need different amounts of padding;
    # mixing up the two axes when padding or cropping then yields a shape mismatch.
    B, H, W = 2, 33, 47
    img = torch.rand(B, 3, H, W, device=device, dtype=dtype)
    jpeg_quality = torch.randint(low=0, high=100, size=(B,), device=device, dtype=dtype)
    img_jpeg = kornia.enhance.jpeg_codec_differentiable(img, jpeg_quality)
    assert img_jpeg is not None
    # The codec must return the original (unpadded) shape
    assert img_jpeg.shape == img.shape

def test_multi_batch(self, device, dtype) -> None:
"""Here we test two batch dimensions."""
B, H, W = 4, 32, 32
Expand Down Expand Up @@ -81,13 +90,6 @@ def test_exception(self, device, dtype) -> None:
kornia.enhance.jpeg_codec_differentiable(img, jpeg_quality)
assert "shape must be [" in str(errinfo)

with pytest.raises(Exception) as errinfo:
B, H, W = 2, 31, 31
img = torch.rand(B, 3, H, W, device=device, dtype=dtype)
jpeg_quality = torch.randint(low=0, high=100, size=(B,), device=device, dtype=dtype)
kornia.enhance.jpeg_codec_differentiable(img, jpeg_quality)
assert "divisible" in str(errinfo)

with pytest.raises(Exception) as errinfo:
B, H, W = 4, 32, 32
img = torch.rand(B, 3, H, W, device=device, dtype=dtype)
Expand Down