In [None]:
#default_exp preprocessing

In [None]:
#exporti
import torch
from typing import Union

In [None]:
#hide
from nbdev.showdoc import show_doc

# Preprocessing

In [None]:
#export
class Preprocessing:
    """
    A parent class for all data preprocessing strategies.

    Subclasses must override `__call__`, `_get_shape` and `_get_vector_directions`.
    The two `_get_*` hooks are invoked once from `__init__`; their results are cached
    and exposed through the read-only `shape` and `vector_directions` properties.
    The constructor checks that the number of vector directions equals `shape`.
    """
    def __init__(self,
                 preprocessing_type:str, # The type of the preprocessing, which can either be "problem" or "solution". Problem preprocessing preprocesses the data based on problem information like forces etc., while solution preprocessing uses solution specific information like stresses, densities etc.
                 name:Union[str, None]=None, # The name of the preprocessing.
                 normalize:bool=False # Whether to normalize the output of the preprocessing.
                ):
        self.preprocessing_type = preprocessing_type
        self.name = name
        self._normalize = normalize
        # The hooks run after the attributes above are set, so subclass
        # implementations of `_get_shape` / `_get_vector_directions` may read them.
        self._shape = self._get_shape()
        self._vector_directions = self._get_vector_directions()
        # Explicit check rather than `assert`: an assert statement is removed when
        # Python runs with optimizations (`-O`), which would silently accept a
        # subclass whose hooks disagree with each other.
        if len(self.vector_directions) != self.shape:
            raise ValueError(
                f"{type(self).__name__}: got {len(self.vector_directions)} vector directions, "
                f"but the preprocessing shape is {self.shape}."
            )


    @property
    def normalize(self):
        """Whether the output of the preprocessing should be normalized (set at construction)."""
        return self._normalize


    @property
    def shape(self):
        """The value returned by `_get_shape` at construction (presumably the number of output channels)."""
        return self._shape


    @property
    def vector_directions(self):
        """The value returned by `_get_vector_directions` at construction; its length equals `shape`."""
        return self._vector_directions


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies the preprocessing and returns a `torch.Tensor` object.
        """
        raise NotImplementedError("Must be overridden.")


    def _get_shape(self):
        # Hook for subclasses; called once from `__init__`.
        raise NotImplementedError("Must be overridden.")


    def _get_vector_directions(self):
        # Hook for subclasses; called once from `__init__`. The result must support `len()`.
        raise NotImplementedError("Must be overridden.")


    def __add__(self,
                preprocessing:"dl4to.preprocessing.Preprocessing" # The preprocessing that should be added to the current one.
               ):
        """
        Summation of two preprocessings results in a new combined preprocessing that concatenates the output of both. Returns a `dl4to.preprocessing.CombinedPreprocessing` object.
        """
        return CombinedPreprocessing(self, preprocessing)

In [None]:
# Render the nbdev API documentation for `Preprocessing.__call__` in the notebook.
show_doc(Preprocessing.__call__)

<h4 id="Preprocessing.__call__" class="doc_header"><code>Preprocessing.__call__</code><a href="__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>Preprocessing.__call__</code>(**`problem_or_solution`**:`Union`\[`ForwardRef('dl4to.problem.Problem')`, `ForwardRef('dl4to.solution.Solution')`\])

Applies the preprocessing and returns a `torch.Tensor` object.

||Type|Default|Details|
|---|---|---|---|
|**`problem_or_solution`**|`typing.Union[ForwardRef('dl4to.problem.Problem'), ForwardRef('dl4to.solution.Solution')]`||A problem or solution object.|


In [None]:
# Render the nbdev API documentation for `Preprocessing.__add__` in the notebook.
show_doc(Preprocessing.__add__)

<h4 id="Preprocessing.__add__" class="doc_header"><code>Preprocessing.__add__</code><a href="__main__.py#L51" class="source_link" style="float:right">[source]</a></h4>

> <code>Preprocessing.__add__</code>(**`preprocessing`**:`dl4to.preprocessing.Preprocessing`)

Summation of two preprocessings results in a new combined preprocessing that concatenates the output of both. Returns a `dl4to.preprocessing.CombinedPreprocessing` object.

||Type|Default|Details|
|---|---|---|---|
|**`preprocessing`**|`dl4to.preprocessing.Preprocessing`||The preprocessing that should be added to the current one.|


In [None]:
#export
class CombinedPreprocessing(Preprocessing):
    """
    A preprocessing that results from the summation of two preprocessings.

    Its output is the concatenation (along `dim=1`) of the outputs of both parts, its
    `shape` is `preprocessing1.shape + preprocessing2.shape`, and its `vector_directions`
    are `preprocessing1.vector_directions + preprocessing2.vector_directions`.
    """
    def __init__(self,
                 preprocessing1:"dl4to.preprocessing.Preprocessing", # The first preprocessing.
                 preprocessing2:"dl4to.preprocessing.Preprocessing", # The second preprocessing.
                ):
        # Both parts must be stored before `super().__init__`, because the base
        # constructor calls `_get_shape` and `_get_vector_directions`, which read them.
        self.preprocessing1 = preprocessing1
        self.preprocessing2 = preprocessing2
        name = f"{preprocessing1.name}_plus_{preprocessing2.name}"
        preprocessing_type = f"{preprocessing1.preprocessing_type}_plus_{preprocessing2.preprocessing_type}"
        # `normalize` is left at its default (False): this object only concatenates
        # the outputs of its parts in `__call__` and applies no normalization itself.
        super().__init__(preprocessing_type=preprocessing_type, name=name)


    def __call__(self,
                 problem_or_solution:Union["dl4to.problem.Problem","dl4to.solution.Solution"] # A problem or solution object.
                ):
        """
        Applies both preprocessings to a problem or solution object and performs channel-wise concatenation of their outputs. Returns a `torch.Tensor` object.
        """
        output1 = self.preprocessing1(problem_or_solution)
        output2 = self.preprocessing2(problem_or_solution)
        # dim=1 is treated as the channel axis; this presumes outputs carry a leading
        # (batch) dimension at dim 0 — TODO confirm against the concrete preprocessings.
        return torch.cat([output1, output2], dim=1)


    def _get_shape(self):
        # The combined output stacks both parts' channels, so their shapes add up.
        return self.preprocessing1.shape + self.preprocessing2.shape


    def _get_vector_directions(self):
        # Concatenated in the same order as the outputs in `__call__` (assumes the parts
        # return sequences such as lists, for which `+` concatenates — verify).
        return self.preprocessing1.vector_directions + self.preprocessing2.vector_directions