In [None]:
#default_exp density_representers

In [None]:
#exporti
import torch

from dl4to.models import DeepImagePrior
from dl4to.density_representers import DensityRepresenter

# Deep Image Prior (DIP) density representer

In [None]:
#export
class DeepImagePriorDensityRepresenter(DensityRepresenter):
    """
    Density representer backed by a Deep Image Prior (DIP) network.

    It holds a DIP module and, when called, performs a forward pass of that module with a noise input; the result is the density. The idea is adapted from [1].
    """
    def __init__(self,
                 problem:"dl4to.problem.Problem"=None, # The problem the density representer is used for. It is needed to guarantee that boundary and design-space constraints are fulfilled. It does not have to be passed at initialization; it can also be supplied later by setting `density_representer.problem`.
                 binarizer_strength:float=1. # The steepness of the smoothed Heaviside function. A binarizer strength of infinity would correspond to the non-smooth classical Heaviside step function.
                ):
        # Both settings are forwarded unchanged to the base class.
        base_kwargs = dict(problem=problem, binarizer_strength=binarizer_strength)
        super().__init__(**base_kwargs)


    def _setup_for_problem(self):
        "Build the DIP network from the shape of `self.problem`."
        problem_shape = self.problem.shape
        self.dip = DeepImagePrior(problem_shape)


    def _apply_density_representer(self):
        "Evaluate the DIP network (called without explicit inputs) and return its output."
        dip_network = self.dip
        return dip_network()

# References

[1] Hoyer, Stephan, Jascha Sohl-Dickstein, and Sam Greydanus. "Neural reparameterization improves structural optimization." arXiv preprint arXiv:1909.04240 (2019).

In [None]:
#hide
from dl4to.datasets import BasicDataset

In [None]:
%%time
#hide

def test_shapes_and_properties(seed:int=0):
    "Check output shape, [0, 1] range and the effect of `binarizer_strength` on the DIP density representer."
    # Seed torch's RNG so the test is reproducible across Restart & Run All.
    # NOTE(review): assumes DeepImagePrior draws its weights / noise input from torch's RNG — confirm.
    torch.manual_seed(seed)

    problem = BasicDataset().ledge()
    representer = DeepImagePriorDensityRepresenter(problem)

    representer.binarizer_strength = 1.
    θ1 = representer()

    assert θ1.shape == (1, *problem.shape), f"unexpected output shape {tuple(θ1.shape)}"

    representer.binarizer_strength = 2.
    θ2 = representer()

    assert θ2.shape == θ1.shape, "output shape must not depend on binarizer_strength"
    assert not torch.all(θ1 == θ2), "changing binarizer_strength had no effect"

    # A steeper binarizer should push densities further away from 0.5. These checks
    # rely on both calls evaluating the same, unchanged DIP output (no training step in between).
    below, above = θ1 < .5, θ1 > .5
    assert torch.all(θ2[below] <= θ1[below]), "densities below 0.5 were not pushed down"
    assert torch.all(θ2[above] >= θ1[above]), "densities above 0.5 were not pushed up"

    for θ in (θ1, θ2):
        assert torch.all((0 <= θ) & (θ <= 1)), "densities must lie in [0, 1]"


test_shapes_and_properties()

CPU times: user 3.21 s, sys: 4.03 ms, total: 3.21 s
Wall time: 416 ms
