In [None]:
#default_exp preprocessing

In [None]:
#exporti
import torch
import warnings
from skimage.morphology import convex_hull_image
from typing import Union

from dl4to.preprocessing import Preprocessing
from dl4to.utils import cast_to_problem

In [None]:
#hide
from nbdev.showdoc import show_doc

# Problem preprocessing

In [None]:
#export
class ProblemPreprocessing(Preprocessing):
    """
    Base class for preprocessing strategies that work on problem-specific data such as forces and boundary conditions.
    It only fixes the preprocessing type to `'problem'`; concrete strategies must implement `__call__`.
    """
    def __init__(self, 
                 name:str=None, # The name of the preprocessing.
                 normalize:bool=False # Whether to normalize the output of the preprocessing.
                ):
        init_kwargs = dict(preprocessing_type='problem', name=name, normalize=normalize)
        super().__init__(**init_kwargs)


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a solution object is passed, then it is automatically converted to a problem object via `solution.problem`.
        This base implementation always raises `NotImplementedError`; subclasses provide the actual transformation.
        """
        message = "Must be overridden."
        raise NotImplementedError(message)

In [None]:
# Render the documentation (signature, docstring and parameter table) of the abstract `__call__` method.
show_doc(ProblemPreprocessing.__call__)

<h4 id="ProblemPreprocessing.__call__" class="doc_header"><code>ProblemPreprocessing.__call__</code><a href="__main__.py#L16" class="source_link" style="float:right">[source]</a></h4>

> <code>ProblemPreprocessing.__call__</code>(**`problem_or_solution`**:`Union`\[`ForwardRef('dl4to.problem.Problem')`, `ForwardRef('dl4to.solution.Solution')`\])

Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
If a solution object is passed, then it is automatically converted to a problem object via `solution.problem`.

||Type|Default|Details|
|---|---|---|---|
|**`problem_or_solution`**|`typing.Union[ForwardRef('dl4to.problem.Problem'), ForwardRef('dl4to.solution.Solution')]`||A problem or solution object.|


In [None]:
#export
class TrivialPreprocessing(ProblemPreprocessing):
    """
    The output of trivial preprocessing [1] is a 7-channel tensor which results from the channel-wise concatenation of Dirichlet boundary conditions, design space information and loads. 
    It is possible to normalize each sample’s F by a dataset-wide constant: the mean, over all samples of a dataset, of each sample's infinity norm (largest absolute entry) of F.
    """
    def __init__(self, 
                 normalize:bool=False, # Whether to normalize the forces in the output of the preprocessing. If True, then a dataset is required.
                 dataset:"dl4to.datasets.TopoDataset"=None # A dataset that is used for the normalization of the forces in the output of the preprocessing. Is only used if `normalize=True`.
                ):
        super().__init__(name="trivial_preprocessing", normalize=normalize)
        # Divisor applied to F in `__call__`; 1. means the forces are passed through unchanged.
        self.L_inf_norm_F = 1.
        if self.normalize:
            if dataset is None:
                warnings.warn("Batch normalization is only possible if a dataset is provided.")
                # NOTE(review): presumably `_normalize` backs the parent's `normalize` attribute — confirm in `Preprocessing`.
                self._normalize = False
            else:
                self.set_normalization_constant(dataset)


    def set_normalization_constant(self, dataset):
        """
        Sets `self.L_inf_norm_F` to the mean of `problem.F.abs().max()` over all `(problem, solution)` pairs in `dataset`
        and returns the constant that is in effect afterwards.
        If the dataset is empty or all of its forces are zero, the previous constant is kept, because dividing by zero
        in `__call__` would turn the output into NaN/inf.
        """
        max_forces = []
        # Index-based access only needs `__len__` and `__getitem__`, which is what map-style datasets provide.
        for i in range(len(dataset)):
            problem, _ = dataset[i]
            max_forces.append(problem.F.abs().max())

        if len(max_forces) == 0:
            return self.L_inf_norm_F
        mean_max_force = sum(max_forces) / len(max_forces)
        if mean_max_force == 0:
            warnings.warn("All forces in the dataset are zero; the forces are not normalized.")
            return self.L_inf_norm_F
        self.L_inf_norm_F = mean_max_force
        return self.L_inf_norm_F


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a solution object is passed, then it is automatically converted to a problem object via `solution.problem`.
        The channels are `Ω_dirichlet`, `Ω_design` and the (possibly normalized) forces `F`, in that order, with a leading batch dimension of size 1.
        """
        problem = cast_to_problem(problem_or_solution)
        F = problem.F / self.L_inf_norm_F
        return torch.cat([problem.Ω_dirichlet, problem.Ω_design, F]).unsqueeze(0)


    def _get_shape(self):
        # Number of output channels.
        return 7


    def _get_vector_directions(self):
        # Only the last three channels (the force components) carry a vector direction.
        return [None, None, None, None, 'x', 'y', 'z']

In [None]:
#export
class ForcePreprocessing(ProblemPreprocessing):
    """
    The output of this preprocessing, i.e. the input of the neural network, is the 3-channel force tensor F.
    It is possible to normalize each sample’s F via the mean absolute maximum over a dataset.
    """
    def __init__(self, 
                 normalize:bool=False, # Whether to normalize the forces in the output of the preprocessing. If True, then a dataset is required.
                 dataset:"dl4to.datasets.TopoDataset"=None # A dataset that is used for the normalization of the forces in the output of the preprocessing. Is only used if `normalize=True`.
                ):
        super().__init__(name="force_preprocessing", normalize=normalize)
        # Divisor applied to F in `__call__`; 1. means the forces are passed through unchanged.
        self.L_inf_norm_F = 1.
        if self.normalize:
            if dataset is None:
                warnings.warn("Batch normalization is only possible if a dataset is provided.")
                # NOTE(review): presumably `_normalize` backs the parent's `normalize` attribute — confirm in `Preprocessing`.
                self._normalize = False
            else:
                self.set_normalization_constant(dataset)


    def set_normalization_constant(self, dataset):
        """
        Sets `self.L_inf_norm_F` to the mean of `problem.F.abs().max()` over all `(problem, solution)` pairs in `dataset`
        and returns the constant that is in effect afterwards.
        If the dataset is empty or all of its forces are zero, the previous constant is kept, because dividing by zero
        in `__call__` would turn the output into NaN/inf.
        """
        max_forces = []
        # Index-based access instead of `for ... in dataset`: a map-style dataset need not define `__iter__`,
        # and Python's `__getitem__` fallback only terminates if `__getitem__` raises `IndexError`.
        for i in range(len(dataset)):
            problem, _ = dataset[i]
            max_forces.append(problem.F.abs().max())

        if len(max_forces) == 0:
            return self.L_inf_norm_F
        mean_max_force = sum(max_forces) / len(max_forces)
        if mean_max_force == 0:
            warnings.warn("All forces in the dataset are zero; the forces are not normalized.")
            return self.L_inf_norm_F
        self.L_inf_norm_F = mean_max_force
        return self.L_inf_norm_F


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a solution object is passed, then it is automatically converted to a problem object via `solution.problem`.
        The result is the (possibly normalized) force tensor `F` with a leading batch dimension of size 1.
        """
        problem = cast_to_problem(problem_or_solution)
        F = problem.F / self.L_inf_norm_F
        return F.unsqueeze(0)


    def _get_shape(self):
        # Number of output channels.
        return 3


    def _get_vector_directions(self):
        return ['x', 'y', 'z']

In [None]:
#export
class ConvexHullPreprocessing(ProblemPreprocessing):
    """
    The convex hull of a binary image is the set of pixels contained in the smallest convex polygon that surrounds all white (nonzero) pixels of the image.
    Convex hull preprocessing [1] generalizes this to 3d voxels and constructs a convex polyhedron with density 1 that connects the force allocation points to points with homogeneous Dirichlet boundary conditions.
    This binary density polyhedron is the output of this preprocessing.
    The hull is computed around the voxels where `problem.Ω_design == 1`.
    """
    def __init__(self):
        super().__init__(name="convex_hull_preprocessing")


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a solution object is passed, then it is automatically converted to a problem object via `solution.problem`.
        The result is a float32 tensor of shape `(1, 1, *problem.Ω_design.shape[1:])` on the same device as `problem.Ω_design`,
        with 1 inside the convex hull and 0 outside.
        """
        problem = cast_to_problem(problem_or_solution)
        Ω_design = problem.Ω_design
        θ = (Ω_design == 1).type(torch.float32)
        assert len(θ.shape) == 4, f"{θ.shape=}"
        # skimage only accepts numpy arrays, so the tensor has to live on the CPU first
        # (`.numpy()` raises for CUDA tensors); the hull is moved back to the input's device.
        θ_np = θ[0].detach().cpu().numpy()
        θ = torch.tensor(convex_hull_image(θ_np), device=Ω_design.device)
        θ = θ.unsqueeze(0)
        assert len(θ.shape) == 4, f"{θ.shape=}"
        return θ.unsqueeze(0).type(torch.float32)


    def _get_shape(self):
        # Single density channel.
        return 1


    def _get_vector_directions(self):
        return [None]

# References

[1] Dittmer, Sören, et al. "SELTO: Sample-Efficient Learned Topology Optimization." arXiv preprint arXiv:2209.05098 (2022).

In [None]:
#hide
from dl4to.datasets import BasicDataset

In [None]:
%%time
#hide

def test_that_shape_is_correct():
    """Every problem preprocessing returns a batch of size 1 whose channel count equals `preprocessing.shape`."""
    problem = BasicDataset(resolution=15).ledge()
    preprocessings = [TrivialPreprocessing(), ForcePreprocessing(), ConvexHullPreprocessing()]
    for preprocessing in preprocessings:
        # Evaluate once per preprocessing instead of once per assertion.
        output = preprocessing(problem)
        label = type(preprocessing).__name__
        assert output.shape[0] == 1, f"{label}: {output.shape=}"
        assert output.shape[1] == preprocessing.shape, f"{label}: {output.shape=}, {preprocessing.shape=}"


test_that_shape_is_correct()

In [None]:
%%time
#hide

def test_that_batch_normalization_is_working():
    """Normalizing with a one-sample dataset scales the largest output entry to exactly 1; without normalization it stays large."""
    problem = BasicDataset().ledge()
    # Datasets yield (problem, solution) pairs; the solution is unused here, so an explicit `None`
    # placeholder is used rather than IPython's `_` (the last cell output, i.e. hidden kernel state).
    dataset = [(problem, None)]
    for preprocessing_class in [TrivialPreprocessing, ForcePreprocessing]:
        preprocessing_normalized = preprocessing_class(normalize=True, dataset=dataset)
        preprocessing_unnormalized = preprocessing_class(normalize=False)
        assert preprocessing_normalized(problem).abs().max() == 1, preprocessing_class.__name__
        # NOTE(review): assumes the raw ledge forces exceed 1e5 in magnitude.
        assert preprocessing_unnormalized(problem).abs().max() > 1e5, preprocessing_class.__name__


test_that_batch_normalization_is_working()