Skip to content

Commit

Permalink
feat: in_range filtering (#2895)
Browse files Browse the repository at this point in the history
* initial commit

* add docs

* add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update docs

* correct typing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests and docs

* change randn to rand

* Modify docs indentation

* correct docs and remove unused vars

* add return_mask

* Remove shape in doc

* fix docs format

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: edgar <edgar.riba@gmail.com>
  • Loading branch information
3 people committed May 15, 2024
1 parent 9d319ab commit bdd07f3
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 0 deletions.
23 changes: 23 additions & 0 deletions docs/generate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,29 @@ def main():
sig = f"{fn_name}({', '.join([str(a) for a in args])})"
print(f"Generated image example for {fn_name}. {sig}")

# kornia.filters.in_range
mod = importlib.import_module("kornia.filters")
transforms: dict = {
"in_range": (((0.314, 0.2, 0.2), (0.47, 1.0, 1.0), True), 1),
}
# ITERATE OVER THE TRANSFORMS
for fn_name, (args, num_samples) in transforms.items():
img_hsv = K.color.rgb_to_hsv(img1)
h, s, v = torch.split(img_hsv, split_size_or_sections=1, dim=1)
h = h / (2 * torch.pi)
img_hsv = torch.cat((h, s, v), dim=1)
args_in = (img_hsv, *args)
fn = getattr(mod, fn_name)
mask = fn(*args_in)
filtered = img1 * mask
mask = mask.repeat(1, img1.shape[1], 1, 1)
# save the output image
out = torch.cat([img1[0], mask[0], filtered[0]], dim=-1)
out_np = K.utils.tensor_to_image((out * 255.0).byte())
cv2.imwrite(str(OUTPUT_PATH / f"{fn_name}.png"), out_np)
sig = f"{fn_name}({', '.join([str(a) for a in args])})"
print(f"Generated image example for {fn_name}. {sig}")

# korna.geometry.transform module
mod = importlib.import_module("kornia.geometry.transform")
h, w = img6.shape[-2:]
Expand Down
5 changes: 5 additions & 0 deletions docs/source/filters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Interactive Demo
Visit the `Kornia edge detector demo on the Hugging Face Spaces
<https://huggingface.co/spaces/kornia/edge_detector>`_.

Segmentation
--------------

.. autofunction:: in_range
.. autoclass:: InRange

Filtering API
-------------
Expand Down
3 changes: 3 additions & 0 deletions kornia/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .filter import filter2d, filter2d_separable, filter3d
from .gaussian import GaussianBlur2d, gaussian_blur2d, gaussian_blur2d_t
from .guided import GuidedBlur, guided_blur
from .in_range import InRange, in_range
from .kernels import (
gaussian,
get_binary_kernel2d,
Expand Down Expand Up @@ -67,6 +68,8 @@
"get_diff_kernel2d",
"gaussian_blur2d",
"guided_blur",
"InRange",
"in_range",
"laplacian",
"laplacian_1d",
"unsharp_mask",
Expand Down
176 changes: 176 additions & 0 deletions kornia/filters/in_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from __future__ import annotations

from typing import Any, Union

import torch

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK
from kornia.utils.image import perform_keep_shape_image


@perform_keep_shape_image
def in_range(
input: Tensor,
lower: Union[tuple[Any, ...], Tensor],
upper: Union[tuple[Any, ...], Tensor],
return_mask: bool = False,
) -> Tensor:
r"""Creates a mask indicating whether elements of the input tensor are within the specified range.
.. image:: _static/img/in_range.png
The formula applied for single-channel tensor is:
.. math::
\text{out}(I) = \text{lower}(I) \leq \text{input}(I) \geq \text{upper}(I)
The formula applied for multi-channel tensor is:
.. math::
\text{out}(I) = \bigwedge_{c=0}^{C}
\left( \text{lower}_c(I) \leq \text{input}_c(I) \geq \text{upper}_c(I) \right)
where `C` is the number of channels.
Args:
input: The input tensor to be filtered in the shape of :math:`(*, *, H, W)`.
lower: The lower bounds of the filter (inclusive).
upper: The upper bounds of the filter (inclusive).
return_mask: If is true, the filtered mask is returned, otherwise the filtered input image.
Returns:
A binary mask :math:`(*, 1, H, W)` of input indicating whether elements are within the range
or filtered input image :math:`(*, *, H, W)`.
Raises
ValueError: If the shape of `lower`, `upper`, and `input` image channels do not match.
.. note::
Clarification of `lower` and `upper`:
- If provided as a tuple, it should have the same number of elements as the channels in the input tensor.
This bound is then applied uniformly across all batches.
- When provided as a tensor, it allows for different bounds to be applied to each batch.
The tensor shape should be (B, C, 1, 1), where B is the batch size and C is the number of channels.
- If the tensor has a 1-D shape, same bound will be applied across all batches.
Examples:
>>> rng = torch.manual_seed(1)
>>> input = torch.rand(1, 3, 3, 3)
>>> lower = (0.2, 0.3, 0.4)
>>> upper = (0.8, 0.9, 1.0)
>>> mask = in_range(input, lower, upper, return_mask=True)
>>> mask
tensor([[[[1., 1., 0.],
[0., 0., 0.],
[0., 1., 1.]]]])
>>> mask.shape
torch.Size([1, 1, 3, 3])
Apply different bounds (`lower` and `upper`) for each batch:
>>> rng = torch.manual_seed(1)
>>> input_tensor = torch.rand((2, 3, 3, 3))
>>> input_shape = input_tensor.shape
>>> lower = torch.tensor([[0.2, 0.2, 0.2], [0.2, 0.2, 0.2]]).reshape(input_shape[0], input_shape[1], 1, 1)
>>> upper = torch.tensor([[0.6, 0.6, 0.6], [0.8, 0.8, 0.8]]).reshape(input_shape[0], input_shape[1], 1, 1)
>>> mask = in_range(input_tensor, lower, upper, return_mask=True)
>>> mask
tensor([[[[0., 0., 1.],
[0., 0., 0.],
[1., 0., 0.]]],
<BLANKLINE>
<BLANKLINE>
[[[0., 0., 0.],
[1., 0., 0.],
[0., 0., 1.]]]])
"""
input_shape = input.shape

KORNIA_CHECK(
isinstance(lower, (tuple, Tensor)) and isinstance(upper, (tuple, Tensor)),
"Invalid `lower` and `upper` format. Should be tuple or Tensor.",
)
KORNIA_CHECK(
isinstance(return_mask, bool),
"Invalid `return_mask` format. Should be boolean.",
)

if isinstance(lower, tuple) and isinstance(upper, tuple):
if len(lower) != input_shape[1] or len(upper) != input_shape[1]:
raise ValueError("Shape of `lower`, `upper` and `input` image channels must have same shape.")

lower = (
torch.tensor(lower, device=input.device, dtype=input.dtype)
.reshape(1, -1, 1, 1)
.repeat(input_shape[0], 1, 1, 1)
)
upper = (
torch.tensor(upper, device=input.device, dtype=input.dtype)
.reshape(1, -1, 1, 1)
.repeat(input_shape[0], 1, 1, 1)
)

elif isinstance(lower, Tensor) and isinstance(upper, Tensor):
valid_tensor_shape = (input_shape[0], input_shape[1], 1, 1)
if valid_tensor_shape not in (lower.shape, upper.shape):
raise ValueError(
"`lower` and `upper` bounds as Tensors must have compatible shapes with the input (B, C, 1, 1)."
)
lower = lower.to(input)
upper = upper.to(input)

# Apply lower and upper bounds. Combine masks with logical_and.
mask = torch.logical_and(input >= lower, input <= upper)
mask = mask.all(dim=(1), keepdim=True).to(input.dtype)

if return_mask:
return mask

return input * mask


class InRange(Module):
r"""Creates a module for applying lower and upper bounds to input tensors.
Args:
input: The input tensor to be filtered.
lower: The lower bounds of the filter (inclusive).
upper: The upper bounds of the filter (inclusive).
return_mask: If is true, the filtered mask is returned, otherwise the filtered input image.
Returns:
A binary mask :math:`(*, 1, H, W)` of input indicating whether elements are within the range
or filtered input image :math:`(*, *, H, W)`.
.. note::
View complete documentation in :func:`kornia.filters.in_range`.
Examples:
>>> rng = torch.manual_seed(1)
>>> input = torch.rand(1, 3, 3, 3)
>>> lower = (0.2, 0.3, 0.4)
>>> upper = (0.8, 0.9, 1.0)
>>> mask = InRange(lower, upper, return_mask=True)(input)
>>> mask
tensor([[[[1., 1., 0.],
[0., 0., 0.],
[0., 1., 1.]]]])
"""

def __init__(
self,
lower: Union[tuple[Any, ...], Tensor],
upper: Union[tuple[Any, ...], Tensor],
return_mask: bool = False,
) -> None:
super().__init__()
self.lower = lower
self.upper = upper
self.return_mask = return_mask

def forward(self, input: Tensor) -> Tensor:
return in_range(input, self.lower, self.upper, self.return_mask)
124 changes: 124 additions & 0 deletions tests/filters/test_in_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import re

import pytest
import torch

from kornia.filters import InRange, in_range

from testing.base import BaseTester, assert_close


def test_in_range(device, dtype):
torch.manual_seed(1)
input_tensor = torch.rand(1, 3, 3, 3, device=device)
input_tensor = input_tensor.to(dtype=dtype)
expected = torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]], device=device, dtype=dtype)
lower = (0.2, 0.3, 0.4)
upper = (0.8, 0.9, 1.0)
result = in_range(input_tensor, lower, upper, return_mask=True)

assert_close(result, expected, atol=1e-4, rtol=1e-4)


class TestInRange(BaseTester):
def _get_expected(self, device, dtype):
return torch.tensor(
[[[[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]],
device=device,
dtype=dtype,
)

def test_smoke(self, device, dtype):
torch.manual_seed(1)
input_tensor = torch.rand(1, 3, 3, 3, device=device)
input_tensor = input_tensor.to(dtype=dtype)
expected = self._get_expected(device=device, dtype=dtype)
res = InRange(lower=(0.2, 0.3, 0.4), upper=(0.8, 0.9, 1.0), return_mask=True)(input_tensor)
assert expected.shape == res.shape
self.assert_close(res, expected, rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize(
"input_shape, lower, upper",
[
((1, 3, 3, 3), (0.2, 0.2, 0.2), (0.6, 0.6, 0.6)),
((2, 3, 3, 3), (0.2, 0.2, 0.2), (0.6, 0.6, 0.6)),
((5, 5, 3, 3), (0.2, 0.2, 0.2, 0.2, 0.2), (0.6, 0.6, 0.6, 0.6, 0.6)),
((3, 3), (0.2,), (0.6,)),
((2, 3, 3), (0.2, 0.2), (0.6, 0.6)),
],
)
def test_cardinality(self, input_shape, lower, upper, device, dtype):
input_tensor = torch.rand(input_shape, device=device, dtype=dtype)
res = InRange(lower=lower, upper=upper, return_mask=True)(input_tensor)

if len(input_tensor.shape) == 2:
assert res.shape == (res.shape[-2], res.shape[-1])
elif len(input_tensor.shape) == 3:
assert res.shape == (1, res.shape[-2], res.shape[-1])
else:
assert res.shape == (res.shape[0], 1, res.shape[-2], res.shape[-1])

def test_exception(self, device, dtype):
input_tensor = torch.rand(1, 3, 3, 3, device=device, dtype=dtype)
with pytest.raises(Exception, match="Invalid `lower` and `upper` format. Should be tuple or Tensor."):
InRange(lower=3, upper=3)(input_tensor)

with pytest.raises(Exception, match="Invalid `lower` and `upper` format. Should be tuple or Tensor."):
InRange(lower=[0.2, 0.2], upper=[0.2, 0.2])(input_tensor)

with pytest.raises(Exception, match="Invalid `lower` and `upper` format. Should be tuple or Tensor."):
InRange(lower=(0.2), upper=(0.2))(input_tensor)

with pytest.raises(
ValueError, match="Shape of `lower`, `upper` and `input` image channels must have same shape."
):
InRange(lower=(0.2,), upper=(0.2,))(input_tensor)

with pytest.raises(
ValueError,
match=re.escape(
"`lower` and `upper` bounds as Tensors must have compatible shapes with the input (B, C, 1, 1)."
),
):
lower = torch.tensor([0.2, 0.2, 0.2])
upper = torch.tensor([0.6, 0.6, 0.6])
InRange(lower=lower, upper=upper)(input_tensor)

with pytest.raises(Exception, match="Invalid `return_mask` format. Should be boolean."):
lower = torch.tensor([0.2, 0.2, 0.2])
upper = torch.tensor([0.6, 0.6, 0.6])
InRange(lower=lower, upper=upper, return_mask=2)(input_tensor)

def test_noncontiguous(self, device, dtype):
batch_size = 3
inp = torch.rand(1, 3, 5, 5, device=device, dtype=dtype).expand(batch_size, -1, -1, -1)
actual = InRange((0.2, 0.2, 0.2), (0.6, 0.6, 0.6), return_mask=True)(inp)
assert actual.is_contiguous()

def test_gradcheck(self, device):
batch_size, channels, height, width = 1, 3, 5, 5
img = torch.rand(batch_size, channels, height, width, device=device, dtype=torch.float64)
self.gradcheck(in_range, (img, (0.2, 0.2, 0.2), (0.6, 0.6, 0.6), True))

@pytest.mark.parametrize(
"input_shape, lower, upper",
[
((1, 3, 3, 3), (0.2, 0.2, 0.2), (0.6, 0.6, 0.6)),
((2, 3, 3, 3), (0.2, 0.2, 0.2), (0.6, 0.6, 0.6)),
((3, 3), (0.2,), (0.6,)),
],
)
def test_module(self, input_shape, lower, upper, device, dtype):
img = torch.rand(input_shape, device=device, dtype=dtype)
op = in_range
op_module = InRange(lower=lower, upper=upper, return_mask=True)
actual = op_module(img)
expected = op(img, lower, upper, True)
self.assert_close(actual, expected)

@pytest.mark.parametrize("batch_size", [1, 2])
def test_dynamo(self, batch_size, device, dtype, torch_optimizer):
inpt = torch.rand(batch_size, 3, 5, 5, device=device, dtype=dtype)
op = InRange(lower=(0.2, 0.2, 0.2), upper=(0.6, 0.6, 0.6), return_mask=True)
op_optimized = torch_optimizer(op)
self.assert_close(op(inpt), op_optimized(inpt))

0 comments on commit bdd07f3

Please sign in to comment.