Skip to content

Commit

Permalink
implement get_gaussian_kernel, get_gaussian_kernel2d, gaussian_blur a…
Browse files Browse the repository at this point in the history
…nd GaussianBlur
  • Loading branch information
edgarriba committed Jan 9, 2019
1 parent 06bdd03 commit 0ff6f5e
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 0 deletions.
45 changes: 45 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

import torch
import torchgeometry.image as image
from torch.autograd import gradcheck

import utils
from common import TEST_DEVICES

@pytest.mark.parametrize("window_size", [5, 11, 15])
@pytest.mark.parametrize("sigma", [1.5, 5.0, 21.0])
def test_get_gaussian_kernel(window_size, sigma):
    """Check the 1D kernel has the requested length and unit sum."""
    coefficients = image.get_gaussian_kernel(window_size, sigma)
    expected_shape = (window_size,)
    assert coefficients.shape == expected_shape

    total = coefficients.sum().item()
    assert total == pytest.approx(1.0)

@pytest.mark.parametrize("ksize_x", [5, 11])
@pytest.mark.parametrize("ksize_y", [3, 7])
@pytest.mark.parametrize("sigma", [1.5, 21.0])
def test_get_gaussian_kernel2d(ksize_x, ksize_y, sigma):
    """Check the 2D kernel has the requested shape and unit sum."""
    window = (ksize_x, ksize_y)
    kernel_2d = image.get_gaussian_kernel2d(window, (sigma, sigma))
    assert kernel_2d.shape == window

    total = kernel_2d.sum().item()
    assert total == pytest.approx(1.0)

@pytest.mark.parametrize("ksize_x", [5, 11])
@pytest.mark.parametrize("ksize_y", [3, 7])
@pytest.mark.parametrize("sigma", [1.5, 21.0])
@pytest.mark.parametrize("device_type", TEST_DEVICES)
@pytest.mark.parametrize("batch_shape",
                         [(1, 1, 10, 16), (1, 4, 8, 15), (2, 3, 11, 7)])
def test_gaussian_blur(batch_shape, device_type, ksize_x, ksize_y, sigma):
    """Check both blur interfaces keep the input shape and pass gradcheck."""
    window = (ksize_x, ksize_y)
    sigma_xy = (sigma, sigma)
    device = torch.device(device_type)
    sample = torch.rand(batch_shape).to(device)

    # module interface
    blur = image.GaussianBlur(window, sigma_xy)
    assert blur(sample).shape == batch_shape

    # functional interface
    blurred = image.gaussian_blur(sample, window, sigma_xy)
    assert blurred.shape == batch_shape

    # evaluate function gradient; the helper presumably prepares the tensor
    # for gradcheck (its implementation is not visible here)
    sample = utils.tensor_to_gradcheck_var(sample)
    assert gradcheck(blur, (sample,), raise_exception=True)
2 changes: 2 additions & 0 deletions torchgeometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
from .utils import *
from .imgwarp import *

from torchgeometry import image


# Version string of the torchgeometry package.
__version__ = '0.1.2rc1'
2 changes: 2 additions & 0 deletions torchgeometry/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .gaussian import get_gaussian_kernel, get_gaussian_kernel2d
from .gaussian import GaussianBlur, gaussian_blur
168 changes: 168 additions & 0 deletions torchgeometry/image/gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.functional import conv2d

def gaussian(window_size, sigma):
    """Return a normalized 1D Gaussian window.

    The window is sampled at the integer positions ``0 .. window_size - 1``
    and centered on ``window_size // 2``.

    Args:
        window_size (int): number of samples in the window.
        sigma (float): standard deviation of the Gaussian.

    Returns:
        Tensor: 1D tensor with ``window_size`` entries that sum to one, in
        the default floating point dtype.

    Raises:
        ZeroDivisionError: if ``sigma`` is zero.
    """
    # Plain Python division on purpose: a zero sigma fails loudly here
    # instead of silently filling the window with inf/nan.
    inv_two_sigma_sq = 1.0 / float(2 * sigma**2)
    # Signed distance of every sample from the window center, built in one
    # vectorized call rather than stacking one scalar tensor per sample.
    offsets = torch.arange(window_size, dtype=torch.get_default_dtype())
    offsets = offsets - window_size // 2
    gauss = torch.exp(-offsets.pow(2) * inv_two_sigma_sq)
    return gauss / gauss.sum()

def get_gaussian_kernel(ksize: int, sigma: float) -> torch.Tensor:
    r"""Compute the coefficients of a 1D Gaussian filter.

    Args:
        ksize (int): filter size. It should be odd and positive.
        sigma (float): gaussian standard deviation.

    Returns:
        Tensor: 1D tensor with gaussian filter coefficients.

    Raises:
        TypeError: if ``ksize`` is not an odd positive integer.

    Shape:
        - Output: :math:`(ksize,)`

    Examples::

        >>> tgm.image.get_gaussian_kernel(3, 2.5)
        tensor([0.3243, 0.3513, 0.3243])

        >>> tgm.image.get_gaussian_kernel(5, 1.5)
        tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
    """
    is_odd_positive_int = (
        isinstance(ksize, int) and ksize > 0 and ksize % 2 == 1)
    if not is_odd_positive_int:
        raise TypeError("ksize must be an odd positive integer. Got {}"
                        .format(ksize))
    return gaussian(ksize, sigma)

def get_gaussian_kernel2d(ksize: Tuple[int, int],
                          sigma: Tuple[float, float]) -> torch.Tensor:
    r"""Compute the coefficient matrix of a 2D Gaussian filter.

    The matrix is the outer product of two 1D Gaussian kernels, one per
    direction.

    Args:
        ksize (Tuple[int, int]): filter sizes in the x and y direction.
          Sizes should be odd and positive.
        sigma (Tuple[float, float]): gaussian standard deviation in the x
          and y direction.

    Returns:
        Tensor: 2D tensor with gaussian filter matrix coefficients.

    Raises:
        TypeError: if ``ksize`` or ``sigma`` is not a tuple of length two.

    Shape:
        - Output: :math:`(ksize_x, ksize_y)`

    Examples::

        >>> tgm.image.get_gaussian_kernel2d((3, 3), (1.5, 1.5))
        tensor([[0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947]])

        >>> tgm.image.get_gaussian_kernel2d((3, 5), (1.5, 1.5))
        tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
                [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
                [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
    """
    if not (isinstance(ksize, tuple) and len(ksize) == 2):
        raise TypeError("ksize must be a tuple of length two. Got {}"
                        .format(ksize))
    if not (isinstance(sigma, tuple) and len(sigma) == 2):
        raise TypeError("sigma must be a tuple of length two. Got {}"
                        .format(sigma))

    # Column vector (ksize_x, 1) times row vector (1, ksize_y) gives the
    # separable 2D kernel.
    column = get_gaussian_kernel(ksize[0], sigma[0]).unsqueeze(-1)
    row = get_gaussian_kernel(ksize[1], sigma[1]).unsqueeze(0)
    return torch.matmul(column, row)

class GaussianBlur(nn.Module):
    r"""Creates an operator that blurs a tensor using a Gaussian filter.

    The operator smooths the given tensor with a gaussian kernel by convolving
    it to each channel. It supports batched operation.

    The kernel is stored as a plain tensor attribute, so it is converted to
    the device and dtype of the input on every forward call.

    Arguments:
        kernel_size (Tuple[int, int]): the size of the kernel.
        sigma (Tuple[float, float]): the standard deviation of the kernel.

    Returns:
        Tensor: the blurred tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples::

        >>> input = torch.rand(2, 4, 5, 5)
        >>> gauss = tgm.image.GaussianBlur((3, 3), (1.5, 1.5))
        >>> output = gauss(input)  # 2x4x5x5
    """

    def __init__(self, kernel_size: Tuple[int, int],
                 sigma: Tuple[float, float]) -> None:
        super().__init__()
        self.kernel_size: Tuple[int, int] = kernel_size
        self.sigma: Tuple[float, float] = sigma
        # Build the kernel first: its construction validates the arguments,
        # so malformed input is reported before the padding is derived.
        self.kernel: torch.Tensor = self.create_gaussian_kernel(
            kernel_size, sigma)
        self._padding: Tuple[int, int] = self.compute_zero_padding(
            kernel_size)

    @staticmethod
    def create_gaussian_kernel(kernel_size, sigma) -> torch.Tensor:
        """Returns a 2D Gaussian kernel array."""
        return get_gaussian_kernel2d(kernel_size, sigma)

    @staticmethod
    def compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:
        """Computes the zero padding tuple, ``(k - 1) // 2`` per dimension."""
        return tuple((k - 1) // 2 for k in kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur ``x`` channel by channel with the stored Gaussian kernel.

        Args:
            x (Tensor): input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            Tensor: the result of the grouped convolution of ``x`` with the
            kernel.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` does not have four dimensions.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input x type is not a torch.Tensor. Got {}"
                            .format(type(x)))
        if len(x.shape) != 4:
            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                             .format(x.shape))

        channels = x.shape[1]
        # Match the input in a single conversion, then replicate the 2D
        # kernel into a (C, 1, kH, kW) weight: one filter per channel.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        kernel = kernel.repeat(channels, 1, 1, 1)

        # groups=C convolves every channel independently with its own copy
        # of the kernel.
        return conv2d(x, kernel, padding=self._padding, stride=1,
                      groups=channels)


######################
# functional interface
######################


def gaussian_blur(src: torch.Tensor, kernel_size: Tuple[int, int],
                  sigma: Tuple[float, float]) -> torch.Tensor:
    r"""Blur a tensor with a Gaussian filter (functional interface).

    Builds a :class:`GaussianBlur` operator from ``kernel_size`` and
    ``sigma`` and applies it to ``src``.

    Arguments:
        src (Tensor): the input tensor.
        kernel_size (Tuple[int, int]): the size of the kernel.
        sigma (Tuple[float, float]): the standard deviation of the kernel.

    Returns:
        Tensor: the blurred tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples::

        >>> input = torch.rand(2, 4, 5, 5)
        >>> output = tgm.image.gaussian_blur(input, (3, 3), (1.5, 1.5))  # 2x4x5x5
    """
    blur_op = GaussianBlur(kernel_size, sigma)
    return blur_op(src)

0 comments on commit 0ff6f5e

Please sign in to comment.