Skip to content

Commit

Permalink
[Fix] Add Support for Images not Div by 16 (Diff. JPEG) (#2865)
Browse files Browse the repository at this point in the history
* Add support for images not div by 16 (diff JPEG)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update kornia/enhance/jpeg.py

Co-authored-by: Jian Shi <sj8716643@126.com>

* Remove if statements (both in pad and crop)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jian Shi <sj8716643@126.com>
  • Loading branch information
3 people committed Apr 3, 2024
1 parent 9a5337c commit 8f0f0c4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
35 changes: 29 additions & 6 deletions kornia/enhance/jpeg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import math
from typing import Tuple

import torch
import torch.nn.functional as F

from kornia.color import rgb_to_ycbcr, ycbcr_to_rgb
from kornia.constants import pi
Expand Down Expand Up @@ -344,6 +348,27 @@ def _jpeg_decode(
return rgb_decoded


def _perform_padding(image: Tensor) -> Tuple[Tensor, int, int]:
    """Pad an image at the bottom and right so that both spatial sizes are divisible by 16.

    Args:
        image: Image of the shape :math:`(*, 3, H, W)`.

    Returns:
        image_padded: Padded image of the shape :math:`(*, 3, H_{new}, W_{new})`.
        h_pad: Number of rows appended at the bottom (padding along the height axis).
        w_pad: Number of columns appended on the right (padding along the width axis).
    """
    # Get spatial dimensions of the image
    H, W = image.shape[-2:]
    # Distance to the next multiple of 16 (0 if already aligned). Integer modulo
    # gives the same result as ``ceil(H / 16) * 16 - H`` without going through floats.
    h_pad: int = (-H) % 16
    w_pad: int = (-W) % 16
    # F.pad orders the last two dims as (left, right, top, bottom); we follow JPEG
    # and pad only the bottom and right side of the image.
    image_padded: Tensor = F.pad(image, (0, w_pad, 0, h_pad), "replicate")
    return image_padded, h_pad, w_pad


@perform_keep_shape_image
def jpeg_codec_differentiable(
image_rgb: Tensor,
Expand Down Expand Up @@ -437,10 +462,6 @@ def jpeg_codec_differentiable(
KORNIA_CHECK_IS_TENSOR(quantization_table_c)
# Check shape of inputs
KORNIA_CHECK_SHAPE(image_rgb, ["*", "3", "H", "W"])
KORNIA_CHECK(
(image_rgb.shape[-1] % 16 == 0) and (image_rgb.shape[-2] % 16 == 0),
f"image dimension must be divisible by 16. Got the shape {image_rgb.shape}.",
)
KORNIA_CHECK_SHAPE(jpeg_quality, ["B"])
# Add batch dimension to quantization tables if needed
if quantization_table_y.ndim == 2:
Expand All @@ -456,6 +477,8 @@ def jpeg_codec_differentiable(
f"JPEG quality is out of range. Expected range is [0, 100], "
f"got [{jpeg_quality.amin().item()}, {jpeg_quality.amax().item()}]. Consider clipping jpeg_quality.",
)
# Pad the image to a shape divisible by 16
image_rgb, h_pad, w_pad = _perform_padding(image_rgb)
# Get height and width (of the padded image)
H, W = image_rgb.shape[-2:]
# Check matching batch dimensions
Expand Down Expand Up @@ -499,8 +522,8 @@ def jpeg_codec_differentiable(
)
# Clip coded image
image_rgb_jpeg = differentiable_clipping(input=image_rgb_jpeg, min_val=0.0, max_val=255.0)
# Back to original shape
# image_rgb_jpeg = image_rgb_jpeg.view(original_shape)
# Crop the image again to the original shape
image_rgb_jpeg = image_rgb_jpeg[..., : H - h_pad, : W - w_pad]
return image_rgb_jpeg


Expand Down
16 changes: 9 additions & 7 deletions tests/enhance/test_jpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def test_smoke(self, device, dtype) -> None:
assert img_jpeg is not None
assert img_jpeg.shape == img.shape

def test_smoke_not_div_by_16(self, device, dtype) -> None:
    """Smoke test on an input whose height and width (33 x 33) are not multiples of 16."""
    batch_size, height, width = 2, 33, 33
    image = torch.rand(batch_size, 3, height, width, device=device, dtype=dtype)
    # One quality value per batch element
    quality = torch.randint(low=0, high=100, size=(batch_size,), device=device, dtype=dtype)
    coded = kornia.enhance.jpeg_codec_differentiable(image, quality)
    assert coded is not None
    # The codec must hand back the input shape
    assert coded.shape == image.shape

def test_multi_batch(self, device, dtype) -> None:
"""Here we test two batch dimensions."""
B, H, W = 4, 32, 32
Expand Down Expand Up @@ -81,13 +90,6 @@ def test_exception(self, device, dtype) -> None:
kornia.enhance.jpeg_codec_differentiable(img, jpeg_quality)
assert "shape must be [" in str(errinfo)

with pytest.raises(Exception) as errinfo:
B, H, W = 2, 31, 31
img = torch.rand(B, 3, H, W, device=device, dtype=dtype)
jpeg_quality = torch.randint(low=0, high=100, size=(B,), device=device, dtype=dtype)
kornia.enhance.jpeg_codec_differentiable(img, jpeg_quality)
assert "divisible" in str(errinfo)

with pytest.raises(Exception) as errinfo:
B, H, W = 4, 32, 32
img = torch.rand(B, 3, H, W, device=device, dtype=dtype)
Expand Down

0 comments on commit 8f0f0c4

Please sign in to comment.