In [None]:
#default_exp density_filters

In [None]:
#exporti
import torch
import warnings
import numpy as np
from torch.nn import Conv3d, Module

In [None]:
#hide
from nbdev.showdoc import show_doc

# Density filters

In [None]:
#export
class DensityFilter(Module):
    """
    Base class from which the different smoothing (density) filters inherit.

    Subclasses implement `_filtering`; calling a filter applies it and checks
    that every value of the result lies in [0, 1].
    """
    def __init__(self, 
                 filter_size:int, # The size of the filter.
                 dtype:torch.dtype=torch.float32 # The datatype of the filter.
                ):
        # `Module.__init__` must run before any attribute is assigned on a torch module.
        super().__init__()
        self._filter_size = filter_size
        self.dtype = dtype


    @property
    def filter_size(self) -> int:
        "The size of the filter."
        return self._filter_size


    def _filtering(self, θ):
        # Implemented by subclasses; returns the filtered tensor.
        raise NotImplementedError("Must be overridden.")


    def __call__(self,
                 θ:torch.Tensor # The input of the filter.
                ):
        """
        Apply the filtering to the input. Returns a `torch.Tensor`.
        """
        θ = self._filtering(θ)
        # The filtered result must stay within [0, 1].
        assert torch.all(0 <= θ), "Filtered values must be >= 0."
        assert torch.all(θ <= 1), "Filtered values must be <= 1."
        return θ

In [None]:
# Render the nbdev documentation for `DensityFilter.__call__`.
show_doc(DensityFilter.__call__)

<h4 id="DensityFilter.__call__" class="doc_header"><code>DensityFilter.__call__</code><a href="__main__.py#L24" class="source_link" style="float:right">[source]</a></h4>

> <code>DensityFilter.__call__</code>(**`θ`**:`Tensor`)

Apply the filtering to the input. Returns a `torch.Tensor`.

||Type|Default|Details|
|---|---|---|---|
|**`θ`**|`Tensor`||The input of the filter.|


In [None]:
#export
class MaxPoolDensityFilter(DensityFilter):
    """
    A filter that applies 3D max pooling with stride 1.

    The output has the same spatial shape as the input for both odd and even
    filter sizes (even sizes are resampled back with nearest interpolation).
    Assumes a batched input of shape (N, C, D, H, W) — TODO confirm for callers.
    """
    def __init__(self, 
                 filter_size:int, # The size of the filter.
                 dtype:torch.dtype=torch.float32 # The datatype of the filter.
                ):
        super().__init__(filter_size, dtype)


    def _filtering(self, θ):
        # Remember the input's spatial shape before θ is reassigned to the pooled result.
        spatial_shape = tuple(θ.shape[2:])
        θ = torch.nn.functional.max_pool3d(θ, kernel_size=self._filter_size, stride=1, padding=self._filter_size//2)

        # With an odd size the symmetric padding already preserves the shape.
        if self._filter_size % 2:
            return θ

        # With an even size the padded pooling yields one extra voxel per spatial
        # dimension; resample back to the *input's* spatial shape.
        return torch.nn.functional.interpolate(θ, size=spatial_shape, mode='nearest')

In [None]:
#export
class ConvolutionDensityFilter(DensityFilter):
    """
    Base class for filters that convolve the input with a fixed, normalized kernel.

    Subclasses implement `_get_kernel`, which must return a non-negative kernel of
    shape (1, 1, filter_size, filter_size, filter_size).
    """
    def __init__(self, 
                 filter_size:int, # The size of the filter.
                 dtype:torch.dtype=torch.float32 # The datatype of the filter.
                ):
        super().__init__(filter_size, dtype)
        # A non-persistent buffer follows `.to(...)`/`.cuda()` together with `self.conv`
        # (keeping the consistency check in `_filtering` valid) without entering the state dict.
        self.register_buffer('kernel', self._get_kernel(), persistent=False)
        self.conv = Conv3d(
            in_channels=1,
            out_channels=1,
            kernel_size=3*[filter_size],
            # Keeps the spatial shape for odd sizes; an even size shrinks each spatial dim by one.
            padding=(filter_size - 1) // 2,
            padding_mode='replicate',
            bias=False,
            dtype=dtype,
        )
        # Load the fixed kernel into the convolution and freeze it.
        with torch.no_grad():
            self.conv.weight.copy_(self.kernel)
        self.conv.requires_grad_(False)


    def _normalize_kernel(self, kernel):
        "Scale a non-negative `kernel` so that its entries sum to one."
        assert torch.all(0 <= kernel)
        kernel = kernel / kernel.sum()
        assert torch.all(kernel <= 1)
        assert torch.all(kernel >= 0)
        return kernel


    def _get_kernel(self):
        # Implemented by subclasses; returns the (1, 1, n, n, n) convolution kernel.
        raise NotImplementedError("Must be overridden.")


    def _filtering(self, θ):
        # The frozen convolution weights must still equal the normalized kernel.
        assert torch.all(self.conv.weight.data <= 1)
        assert torch.all(self.conv.weight.data >= 0)
        assert torch.allclose(self.conv.weight.data, self.kernel)
        return self.conv(θ)

In [None]:
#export
class UniformDensityFilter(ConvolutionDensityFilter):
    """
    A class that performs convolution with a uniform filter, which is also referred to as mean pooling.
    """
    def __init__(self, 
                 filter_size:int, # The size of the filter.
                 dtype:torch.dtype=torch.float32 # The datatype of the filter.
                ):
        # The convolution only preserves the spatial shape for odd sizes, so an
        # even size is bumped up by one.
        if not filter_size % 2:
            filter_size += 1
            warnings.warn(f"filter_size must be an odd number. Automatically setting filter_size to {filter_size}.")

        super().__init__(filter_size, dtype)


    def _get_kernel(self):
        # Every voxel of the n x n x n cube gets the same weight (1 / n**3 after normalization).
        kernel = torch.ones(3 * [self.filter_size], dtype=self.dtype)
        return self._normalize_kernel(kernel).unsqueeze(0).unsqueeze(0)

In [None]:
#export
class RadialDensityFilter(ConvolutionDensityFilter):
    """
    A class that performs convolution with a radial filter. A radial filter is a filter that has its maximal value in the center and decays radially to the outside. 
    All values of the filter sum up to one.

    The weights decrease linearly with the L1 (Manhattan) distance from the central voxel.
    """
    def __init__(self, 
                 filter_size:int, # The size of the filter.
                 dtype:torch.dtype=torch.float32 # The datatype of the filter.
                ):
        # An even size has no central voxel (asymmetric kernel) and the convolution
        # would shrink the output, so bump it to the next odd size.
        if not filter_size % 2:
            filter_size += 1
            warnings.warn(f"filter_size must be an odd number. Automatically setting filter_size to {filter_size}.")

        super().__init__(filter_size, dtype)


    def _get_kernel(self):
        # Absolute offset of every index from the central index along one axis.
        offsets = (torch.arange(self.filter_size, dtype=self.dtype) - self.filter_size // 2).abs()
        # L1 distance from the centre for every voxel of the cube (via broadcasting).
        dist = offsets[:, None, None] + offsets[None, :, None] + offsets[None, None, :]
        # Linear decay that reaches zero at distance `radius`.
        radius = self.filter_size // 2 + 1
        kernel = torch.relu(radius - dist)
        return self._normalize_kernel(kernel).unsqueeze(0).unsqueeze(0)

In [None]:
%%time
#hide

def test_uniform_density_filter():
    "Check output shape and kernel properties of `UniformDensityFilter`."
    for filter_size in [3, 5, 7, 9]:
        density_filter = UniformDensityFilter(filter_size)
        θ = torch.rand(1, 1, 10, 10, 10)
        θ_filtered = density_filter(θ)

        # Keep the input around so the shape is compared against it, not against itself.
        assert θ_filtered.shape == θ.shape, f"{filter_size}, {θ_filtered.shape}"
        assert density_filter.kernel.shape == (1, 1, filter_size, filter_size, filter_size)
        assert torch.all(density_filter.kernel == density_filter.kernel[0,0,0,0,0])
        assert torch.isclose(density_filter.kernel.sum(), torch.tensor(1.0))


test_uniform_density_filter()

CPU times: user 207 ms, sys: 12.4 ms, total: 220 ms
Wall time: 41.2 ms


In [None]:
%%time
#hide

def test_radial_density_filter():
    "Check output shape and kernel properties of `RadialDensityFilter`."
    for filter_size in [3, 5, 7, 9]:
        density_filter = RadialDensityFilter(filter_size)
        θ = torch.rand(1, 1, 10, 10, 10)
        θ_filtered = density_filter(θ)

        # Keep the input around so the shape is compared against it, not against itself.
        assert θ_filtered.shape == θ.shape, f"{filter_size}, {θ_filtered.shape}"
        kernel = density_filter.kernel
        assert kernel.shape == (1, 1, filter_size, filter_size, filter_size)
        # The maximum sits at the central voxel and the kernel is mirror-symmetric.
        assert kernel.argmax() == int(filter_size**3 / 2)
        assert torch.equal(kernel, kernel.flip(2, 3, 4))
        assert torch.isclose(kernel.sum(), torch.tensor(1.0))


test_radial_density_filter()

CPU times: user 743 ms, sys: 0 ns, total: 743 ms
Wall time: 141 ms


In [None]:
%%time
#hide

def test_max_pool_density_filter():
    "Check that `MaxPoolDensityFilter` preserves the shape for odd and even sizes."
    for filter_size in [1, 2, 3, 4, 5]:
        density_filter = MaxPoolDensityFilter(filter_size)
        θ = torch.rand(1, 1, 10, 10, 10)
        θ_filtered = density_filter(θ)

        # Keep the input around so the shape is compared against it, not against itself.
        assert θ_filtered.shape == θ.shape, f"{filter_size}, {θ_filtered.shape}"
        # Each pooling window contains the voxel itself, so values can only grow.
        assert torch.all(θ_filtered >= θ), f"{filter_size}"


test_max_pool_density_filter()

CPU times: user 8.76 ms, sys: 4.17 ms, total: 12.9 ms
Wall time: 3.49 ms
