In [None]:
#default_exp preprocessing

In [None]:
#exporti
import torch
import warnings
from typing import Union

from dl4to.preprocessing import Preprocessing
from dl4to.utils import cast_to_solution

In [None]:
#hide
from nbdev.showdoc import show_doc

# Solution preprocessing

In [None]:
#export
class SolutionPreprocessing(Preprocessing):
    """
    Base class for preprocessing strategies that work on solution-level quantities,
    e.g. stresses, displacements and densities.
    Concrete strategies are expected to implement `__call__`.
    """
    def __init__(self, 
                 name:str=None, # The name of the preprocessing.
                 normalize:bool=False # Whether to normalize the output of the preprocessing.
                ):
        # All subclasses share the 'solution' preprocessing type.
        super().__init__(name=name,
                         normalize=normalize,
                         preprocessing_type='solution')


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a problem object is passed, then it is automatically converted to its trivial solution via `problem.trivial_solution`.
        """
        message = "Must be overridden."
        raise NotImplementedError(message)

In [None]:
# Render the API documentation of the abstract `SolutionPreprocessing.__call__` interface.
show_doc(SolutionPreprocessing.__call__)

<h4 id="SolutionPreprocessing.__call__" class="doc_header"><code>SolutionPreprocessing.__call__</code><a href="__main__.py#L16" class="source_link" style="float:right">[source]</a></h4>

> <code>SolutionPreprocessing.__call__</code>(**`problem_or_solution`**:`Union`\[`ForwardRef('dl4to.problem.Problem')`, `ForwardRef('dl4to.solution.Solution')`\])

Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
If a problem object is passed, then it is automatically converted to its trivial solution via `problem.trivial_solution`.

||Type|Default|Details|
|---|---|---|---|
|**`problem_or_solution`**|`typing.Union[ForwardRef('dl4to.problem.Problem'), ForwardRef('dl4to.solution.Solution')]`||A problem or solution object.|


In [None]:
#export
class PDEPreprocessing(SolutionPreprocessing):
    r"""
    PDE preprocessing [1, 2] computes the von Mises stresses for the trivial solution. If `normalize=True`, we divide the resulting tensor by $20\%$ of the yield stress to obtain outputs that are likely close to the unit interval.
    These initial von Mises stresses are then used as a $1$-channel input to the neural network. It is also possible to use the displacements $u$ as input, or a concatenation of $u$ and $\sigma_{vM}$.
    We found that using the von Mises stresses $\sigma_{vM}$ is usually enough.
    """
    def __init__(self, 
                 use_u:bool=False, # Whether to use the displacements (3 channels) in the preprocessing.
                 use_σ_vm:bool=True, # Whether to use the von Mises stresses (1 channel) in the preprocessing.
                 normalize:bool=False # Whether to rescale the outputs: von Mises stresses are divided by 20% of the yield stress, displacements are multiplied by 100.
                ):
        # NOTE(review): the flags are set before `super().__init__`, presumably because the
        # base constructor queries `_get_shape`/`_get_vector_directions` — confirm against `Preprocessing`.
        # NOTE(review): an earlier version of this doc said `normalize=True` requires a dataset;
        # nothing in this class uses one — confirm against `Preprocessing`.
        self.use_u = use_u
        self.use_σ_vm = use_σ_vm
        if not any([self.use_u, self.use_σ_vm]):
            # Only a warning: calling the preprocessing with no channel selected will fail in `torch.cat`.
            warnings.warn("At least one of `use_u` and `use_σ_vm` must be True.")
        super().__init__(name = "pde_preprocessing", normalize=normalize)


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a problem object is passed, then it is automatically converted to its trivial solution via `problem.trivial_solution`.
        The selected fields (displacements first, then von Mises stresses) are concatenated along
        dim 0 and a leading batch dimension of size 1 is added.
        """
        solution = cast_to_solution(problem_or_solution)
        u, _, σ_vm = solution.solve_pde()
        if self.normalize:
            # The divisor must be parenthesized: `σ_vm / .2 * σ_ys` parses as `(σ_vm / .2) * σ_ys`,
            # which scales by the yield stress instead of normalizing by 20% of it.
            # Out-of-place arithmetic already yields new tensors, so no `.clone()` is needed.
            σ_vm = σ_vm / (.2 * solution.problem.σ_ys)
            u = 1e2 * u
        output = []
        if self.use_u:
            output.append(u)
        if self.use_σ_vm:
            output.append(σ_vm)
        return torch.cat(output, dim=0).unsqueeze(0)


    def _get_shape(self):
        # Number of output channels: 3 for the displacements, 1 for the von Mises stresses.
        return 3 * self.use_u + self.use_σ_vm


    def _get_vector_directions(self):
        # One entry per output channel: x/y/z for displacement components, None for the scalar stress.
        return ['x', 'y', 'z'] * self.use_u + [None] * self.use_σ_vm

In [None]:
#export
class DensityPreprocessing(SolutionPreprocessing):
    """
    Preprocessing whose output is the density distribution θ of the solution,
    optionally binarized, as a single channel.
    """
    def __init__(self, 
                 binary:bool=False, # Whether the density should be binarized.
                 normalize:bool=False # Whether to normalize the output of the preprocessing.
                ):
        super().__init__(name="density_preprocessing", normalize=normalize)
        self.binary = binary


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing to a problem or solution object. Returns a `torch.Tensor` object.
        If a problem object is passed, then it is automatically converted to its trivial solution via `problem.trivial_solution`.
        """
        solution = cast_to_solution(problem_or_solution)
        densities = solution.get_θ(binary=self.binary)
        # Prepend a batch dimension of size 1.
        return densities.unsqueeze(0)


    def _get_shape(self):
        # A single density channel.
        return 1


    def _get_vector_directions(self):
        # The density is a scalar field, so its only channel has no direction.
        return [None]

# References

[1] Dittmer, Sören, et al. "SELTO: Sample-Efficient Learned Topology Optimization." arXiv preprint arXiv:2209.05098 (2022).

[2] Zhang, Yiquan, et al. "A deep convolutional neural network for topology optimization with strong generalization ability." arXiv preprint arXiv:1901.07761 (2019).

In [None]:
#hide
from dl4to.datasets import BasicDataset
from dl4to.pde import FDM

In [None]:
%%time
#hide

def test_that_calling_PDE_preprocessing_with_problem_does_the_same_like_when_calling_with_trivial_solution():
    """Passing a problem must give the same output as passing its trivial solution."""
    ledge_problem = BasicDataset(resolution=15).ledge()
    pde_preprocessing = PDEPreprocessing()
    ledge_problem.pde_solver = FDM()
    from_problem = pde_preprocessing(ledge_problem)
    from_trivial_solution = pde_preprocessing(ledge_problem.trivial_solution)
    assert torch.all(from_problem == from_trivial_solution)

test_that_calling_PDE_preprocessing_with_problem_does_the_same_like_when_calling_with_trivial_solution()

CPU times: user 1.54 s, sys: 130 ms, total: 1.67 s
Wall time: 502 ms


In [None]:
%%time
#hide

def test_that_shape_is_correct():
    """Every solution preprocessing outputs a batch of one whose channel count equals `shape`."""
    problem = BasicDataset(resolution=15).ledge()
    problem.pde_solver = FDM()

    for preprocessing in (PDEPreprocessing(), DensityPreprocessing()):
        output_shape = preprocessing(problem).shape
        assert output_shape[0] == 1
        assert output_shape[1] == preprocessing.shape


test_that_shape_is_correct()

CPU times: user 1.52 s, sys: 31.1 ms, total: 1.55 s
Wall time: 440 ms
