In [None]:
#default_exp models

In [None]:
#exporti
import torch
import torch.nn as nn

In [None]:
#hide
from nbdev.showdoc import show_doc

# Deep Image Prior

In [None]:
#export
class DeepImagePrior(nn.Module):
    """
    The deep image prior (DIP) [1] is a type of convolutional neural network used to enhance a given image with no prior training data other than the image itself.
    A neural network is randomly initialized and used as prior to solve inverse problems such as noise reduction, super-resolution, and inpainting.
    Image statistics are captured by the structure of a convolutional image generator rather than by any previously learned capabilities.

    The network consists of four strided (`stride=2`) `nn.Conv3d` encoding layers followed by four `nn.Conv3d` decoding layers.
    Before each decoding layer the activation is resized with `nn.functional.interpolate` to the spatial size of the corresponding encoder activation.
    The encoder activations are only used as size targets; they are not concatenated or added to the decoder activations.
    The fixed random input `z` is drawn once at construction time and stored as a non-persistent buffer, so it follows the module
    through `.to(...)`, `.cuda()` and `.double()` but does not appear in `state_dict()`.
    """
    def __init__(self, 
                 shape:list, # A list containing three entries that define the number of voxels in each direction.
                 n_channels:int=1, # The number of channels of the fixed noise input, which is also the number of channels of the output.
                 n_inital_channels:int=4 # The number of channels after the first encoding block. The model has a total 4 encoding and 4 decoding blocks, and the number of channels is doubled in each encoding step.
                ):
        super().__init__()
        self.n_channels = n_channels

        # Channel widths after each encoding step: c, 2c, 4c, 8c.
        widths = [factor * n_inital_channels for factor in (1, 2, 4, 8)]

        # Encoder: every layer halves the spatial extent (stride 2) and doubles the channels.
        encoder_channels = [n_channels] + widths
        self.encoder_layers = nn.ModuleList(
            nn.Conv3d(c_in, c_out, 3, stride=2, padding=1)
            for c_in, c_out in zip(encoder_channels[:-1], encoder_channels[1:])
        )

        # Decoder: mirrors the encoder channel widths; the spatial upsampling happens in `forward`.
        decoder_channels = widths[::-1] + [n_channels]
        self.decoder_layers = nn.ModuleList(
            nn.Conv3d(c_in, c_out, 3, stride=1, padding=1)
            for c_in, c_out in zip(decoder_channels[:-1], decoder_channels[1:])
        )

        self.relu = nn.ReLU()

        # Registering the noise input as a buffer (instead of a plain attribute) makes it follow
        # device/dtype changes of the module. `persistent=False` keeps it out of `state_dict()`,
        # so checkpoints stay compatible with the previous plain-attribute implementation.
        self.register_buffer("z", torch.randn(n_channels, *shape, requires_grad=False), persistent=False)


    def forward(self):
        """
        The forward pass of the DIP with a fixed random noise input. Returns a `torch.Tensor` object.

        The returned tensor has the same shape as the noise input `z`, i.e. `(n_channels, *shape)`,
        and its values are the output of a sigmoid.
        """
        # Add a batch dimension of size one in front of the channel dimension.
        encoder_activations = [self.z.view(1, *self.z.shape)]

        # All encoder layers but the last: layer norm -> strided conv -> ReLU.
        for encoder_layer in self.encoder_layers[:-1]:
            activation = encoder_activations[-1]
            activation = nn.functional.layer_norm(activation, activation.shape[1:])
            activation = encoder_layer(activation)
            activation = self.relu(activation)
            encoder_activations.append(activation)

        # The last encoder layer is applied without a preceding layer norm.
        central_activation = self.relu(self.encoder_layers[-1](encoder_activations[-1]))

        decoder_activations = [central_activation]
        last_idx = len(self.decoder_layers) - 1

        # Walk the encoder activations from deepest to shallowest; each one only provides the
        # spatial size the decoder activation is resized to before the convolution.
        for idx, (decoder_layer, encoder_activation) in enumerate(zip(self.decoder_layers, encoder_activations[::-1])):
            activation = decoder_activations[-1]
            activation = nn.functional.layer_norm(activation, activation.shape[1:])
            activation = nn.functional.interpolate(activation, size=encoder_activation.shape[2:])
            activation = decoder_layer(activation)
            # ReLU between decoder layers, sigmoid on the final output.
            activation = self.relu(activation) if idx != last_idx else torch.sigmoid(activation)
            decoder_activations.append(activation)

        # Drop the batch dimension that was added at the start.
        return decoder_activations[-1].squeeze(0)

In [None]:
show_doc(DeepImagePrior.forward)

<h4 id="DeepImagePrior.forward" class="doc_header"><code>DeepImagePrior.forward</code><a href="__main__.py#L32" class="source_link" style="float:right">[source]</a></h4>

> <code>DeepImagePrior.forward</code>()

The forward pass of the DIP with a fixed random noise input. Returns a `torch.Tensor` object.

# References

[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Deep image prior." Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.

In [None]:
#hide
import hypothesis.strategies as st
from hypothesis import given, settings

In [None]:
#hide
# Hypothesis strategies for the shape test: a channel count between 1 and 10,
# and a 3-tuple of per-axis voxel counts, each between 5 and 50.
st_n_channels = st.integers(min_value=1, max_value=10)
st_shape = st.tuples(*(st.integers(min_value=5, max_value=50) for _ in range(3)))

In [None]:
%%time
#hide

@given(n_channels=st_n_channels, shape=st_shape)
@settings(max_examples=5, deadline=None)
def test_shapes(n_channels, shape):
    """Check that the DIP output has shape `(n_channels, *shape)` for randomly drawn inputs."""
    model = DeepImagePrior(shape=shape, n_channels=n_channels)
    output = model()
    expected_shape = (n_channels, *shape)
    assert output.shape == expected_shape


test_shapes()

CPU times: user 421 ms, sys: 114 ms, total: 536 ms
Wall time: 72.6 ms
