In [None]:
#default_exp density_representers

# Filtering density representer

In [None]:
#exporti
import torch

from dl4to.density_representers import DensityRepresenter
from dl4to.density_filters import UniformDensityFilter, RadialDensityFilter, MaxPoolDensityFilter
import warnings

In [None]:
#export

filter_fcts = ['radial', 'uniform', 'max_pool', None]

class FilteringDensityRepresenter(DensityRepresenter):
    """
    A density representer that applies filtering to its latent density distribution.

    The latent densities `θ` are a trainable `torch.nn.Parameter` of shape `(1, *problem.shape)`,
    initialized to the constant `θ_default`. Each time the representer is applied, `θ` is clamped
    to [0, 1], the voxels fixed by `problem.Ω_design` are enforced, and the result is passed
    through the filter selected by `filter_fct`.
    """
    def __init__(self, 
                 problem:"dl4to.problem.Problem"=None, # The problem object for which the density representer is used. The problem object is necessary to guarantee that boundary and design space constraints are fulfilled. However, the problem does not need to be passed during initialization but can also be passed later by overriding `density_representer.problem`.
                 filter_size:int=3, # The size of the filter kernel.
                 filter_fct:str='radial', # The type of filtering strategy that is used. Possible options are "radial", "uniform", "max_pool" and None (no filtering).
                 binarizer_strength:float=1., # The steepness of the smoothed Heaviside-function. A binarizer strength of infinity would correspond to a non-smooth classical Heaviside step function.
                 θ_default:float=.5 # The weighting factor for the trivial solution density that is used as the initialization of the latent density distribution.
                ):
        # Explicit check instead of `assert`, which is stripped when Python runs with `-O`.
        if filter_fct not in filter_fcts:
            raise ValueError(f"Unknown filter_fct {filter_fct!r}; expected one of {filter_fcts}.")

        # Store the configuration before the parent constructor runs: `_setup_for_problem` reads
        # these attributes. NOTE(review): this matters if `DensityRepresenter.__init__` sets up
        # for `problem` immediately when one is passed; its implementation is not visible here.
        self.filter_size = filter_size
        self.filter_fct = filter_fct
        self.θ_default = θ_default

        super().__init__(
            problem=problem,
            binarizer_strength=binarizer_strength
        )


    def _setup_filter(self):
        """Instantiate `self.filter` according to `self.filter_fct`, using the problem's dtype."""
        filter_classes = {
            'radial': RadialDensityFilter,
            'uniform': UniformDensityFilter,
            'max_pool': MaxPoolDensityFilter,
        }
        filter_cls = filter_classes.get(self.filter_fct)
        if filter_cls is None:
            # No filtering. `torch.nn.Identity` is used instead of a lambda because it is picklable
            # and, if this class is a `torch.nn.Module`, can replace a previously assigned filter
            # module (assigning a plain function over a registered submodule raises a TypeError).
            self.filter = torch.nn.Identity()
        else:
            self.filter = filter_cls(filter_size=self.filter_size, dtype=self.problem.dtype)


    def _setup_for_problem(self):
        """Create the latent parameter `θ` for the current problem and set up the filter."""
        # Constant initialization with value `θ_default`; leading dimension of size 1.
        θ = torch.ones(1, *self.problem.shape, dtype=self.problem.dtype) * self.θ_default
        self.θ = torch.nn.Parameter(θ, requires_grad=True)
        self._setup_filter()


    def _apply_density_representer(self):
        """
        Project the latent densities, enforce the design-space constraints and filter them.

        Returns the filtered densities clamped to [0, 1], with the same shape as `self.θ`.
        """
        # Projection step performed on `.data` so that it is not recorded by autograd.
        self.θ.data.clamp_(0, 1)

        # Voxels marked 0 / 1 in `Ω_design` are forced to density 0 / 1 respectively.
        self.θ.data[self.problem.Ω_design == 0] = 0.
        self.θ.data[self.problem.Ω_design == 1] = 1.

        # Add a leading batch dimension for the filter, then remove it again.
        θ = self.filter(self.θ.unsqueeze(0)).squeeze(0)

        # Filters are expected to roughly preserve the [0, 1] range; warn on clear violations
        # before clamping hides them.
        if θ.max().item() > 1.1 or θ.min().item() < -0.1:
            warnings.warn("Density value is too large or too small")

        return θ.clamp(0, 1)

In [None]:
#hide
from dl4to.datasets import BasicDataset

In [None]:
%%time
#hide

def test_shapes():
    """Check output shape and [0, 1] bounds of the representer for several filter types."""
    problem = BasicDataset().ledge()
    for filter_fct in ['radial', 'uniform', None]:
        representer = FilteringDensityRepresenter(filter_size=3, filter_fct=filter_fct)
        representer.problem = problem
        θ = representer()
        # The filter_fct is passed as assertion message so a failure names the offending case.
        assert θ.shape == (1, *problem.shape), filter_fct
        assert torch.all(0 <= θ), filter_fct
        assert torch.all(θ <= 1), filter_fct


test_shapes()

CPU times: user 411 ms, sys: 0 ns, total: 411 ms
Wall time: 85.5 ms
