forked from kornia/kornia
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added unsharp mask filtering (kornia#1004)
* Added functionality for unsharp mask * Added functionality for Unsharp Mask * Implemented the changes to Unsharp Mask functionality * fix linter and jit test * add unsharp to docs * update python version in gh actions format Co-authored-by: AnimeshMaheshwari22 <animesh.m2202@gmail.com> Co-authored-by: AnimeshMaheshwari22 <45392539+AnimeshMaheshwari22@users.noreply.github.com>
- Loading branch information
1 parent
f725260
commit d57a444
Showing
7 changed files
with
143 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,6 +90,7 @@ | |
motion_blur, | ||
filter2D, | ||
filter3D, | ||
unsharp_mask, | ||
) | ||
from kornia.losses import ( | ||
ssim, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
import kornia | ||
from kornia.filters import gaussian_blur2d | ||
|
||
|
||
def unsharp_mask( | ||
input: torch.Tensor, | ||
kernel_size: Tuple[int, int], | ||
sigma: Tuple[float, float], | ||
border_type: str = 'reflect') -> torch.Tensor: | ||
r"""Creates an operator that blurs a tensor using the existing Gaussian filter available with the Kornia library. | ||
Arguments: | ||
input (torch.Tensor): the input tensor with shape :math:`(B,C,H,W)`. | ||
kernel_size (Tuple[int, int]): the size of the kernel. | ||
sigma (Tuple[float, float]): the standard deviation of the kernel. | ||
border_type (str): the padding mode to be applied before convolving. | ||
The expected modes are: ``'constant'``, ``'reflect'``, | ||
``'replicate'`` or ``'circular'``. Default: ``'reflect'``. | ||
Returns: | ||
torch.Tensor: the blurred tensor with shape :math:`(B,C,H,W)`. | ||
Examples: | ||
>>> input = torch.rand(2, 4, 5, 5) | ||
>>> output = unsharp_mask(input, (3, 3), (1.5, 1.5)) | ||
>>> output.shape | ||
torch.Size([2, 4, 5, 5]) | ||
""" | ||
data_blur: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) | ||
data_sharpened: torch.Tensor = input + (input - data_blur) | ||
return data_sharpened | ||
|
||
|
||
class UnsharpMask(nn.Module): | ||
r"""Creates an operator that sharpens image using the existing Gaussian filter available with the Kornia library.. | ||
Arguments: | ||
kernel_size (Tuple[int, int]): the size of the kernel. | ||
sigma (Tuple[float, float]): the standard deviation of the kernel. | ||
border_type (str): the padding mode to be applied before convolving. | ||
The expected modes are: ``'constant'``, ``'reflect'``, | ||
``'replicate'`` or ``'circular'``. Default: ``'reflect'``. | ||
Returns: | ||
Tensor: the sharpened tensor with shape :math:`(B,C,H,W)`. | ||
Shape: | ||
- Input: :math:`(B, C, H, W)` | ||
- Output: :math:`(B, C, H, W)` | ||
Examples: | ||
>>> input = torch.rand(2, 4, 5, 5) | ||
>>> sharpen = UnsharpMask((3, 3), (1.5, 1.5)) | ||
>>> output = sharpen(input) | ||
>>> output.shape | ||
torch.Size([2, 4, 5, 5]) | ||
""" | ||
|
||
def __init__(self, kernel_size: Tuple[int, int], | ||
sigma: Tuple[float, float], | ||
border_type: str = 'reflect') -> None: | ||
super(UnsharpMask, self).__init__() | ||
self.kernel_size: Tuple[int, int] = kernel_size | ||
self.sigma: Tuple[float, float] = sigma | ||
self.border_type = border_type | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
return unsharp_mask(input, self.kernel_size, self.sigma, self.border_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
|
||
import kornia | ||
import kornia.testing as utils # test utils | ||
|
||
import torch | ||
from torch.autograd import gradcheck | ||
from torch.testing import assert_allclose | ||
|
||
|
||
class Testunsharp: | ||
@pytest.mark.parametrize("batch_shape", [(1, 4, 8, 15), (2, 3, 11, 7)]) | ||
def test_cardinality(self, batch_shape, device, dtype): | ||
kernel_size = (5, 7) | ||
sigma = (1.5, 2.1) | ||
|
||
input = torch.rand(batch_shape, device=device, dtype=dtype) | ||
actual = kornia.filters.unsharp_mask(input, kernel_size, sigma, "replicate") | ||
assert actual.shape == batch_shape | ||
|
||
def test_noncontiguous(self, device, dtype): | ||
batch_size = 3 | ||
input = torch.rand(3, 5, 5, device=device, dtype=dtype).expand(batch_size, -1, -1, -1) | ||
|
||
kernel_size = (3, 3) | ||
sigma = (1.5, 2.1) | ||
actual = kornia.filters.unsharp_mask(input, kernel_size, sigma, "replicate") | ||
assert_allclose(actual, actual) | ||
|
||
def test_gradcheck(self, device, dtype): | ||
# test parameters | ||
batch_shape = (1, 3, 5, 5) | ||
kernel_size = (3, 3) | ||
sigma = (1.5, 2.1) | ||
|
||
# evaluate function gradient | ||
input = torch.rand(batch_shape, device=device, dtype=dtype) | ||
input = utils.tensor_to_gradcheck_var(input) # to var | ||
assert gradcheck( | ||
kornia.filters.unsharp_mask, | ||
(input, kernel_size, sigma, "replicate"), | ||
raise_exception=True, | ||
) | ||
|
||
def test_jit(self, device, dtype): | ||
op = kornia.filters.unsharp_mask | ||
op_script = torch.jit.script(op) | ||
params = [(3, 3), (1.5, 1.5)] | ||
|
||
img = torch.ones(1, 3, 5, 5, device=device, dtype=dtype) | ||
assert_allclose(op(img, *params), op_script(img, *params)) | ||
|
||
def test_module(self, device, dtype): | ||
params = [(3, 3), (1.5, 1.5)] | ||
op = kornia.filters.unsharp_mask | ||
op_module = kornia.filters.UnsharpMask(*params) | ||
|
||
img = torch.ones(1, 3, 5, 5, device=device, dtype=dtype) | ||
assert_allclose(op(img, *params), op_module(img)) |