In [None]:
#default_exp models

In [None]:
#exporti
import torch
import torch.nn as nn
from typing import Union

from dl4to.models import ConvolutionalBlock, Encoder, EncodingBlock, Decoder

In [None]:
#hide
from nbdev.showdoc import show_doc

# UNet

In [None]:
#export
class UNet(nn.Module):
    """
    UNets are convolutional autoencoders that were developed for biomedical image segmentation at the Computer Science Department of the University of Freiburg [1].
    The network is based on a fully convolutional network and its architecture was modified and extended to work with fewer training images and to yield more precise segmentations.
    Our code is based on `https://github.com/fepegar/unet/tree/master/unet`.
    """
    def __init__(
        self,
        in_channels:int=7, # The number of input channels.
        out_classes:int=1, # The number of output classes.
        dimensions:int=3, # The number of dimensions to consider. Possible options are 2 and 3.
        num_encoding_blocks:int=4, # The number of encoding blocks.
        out_channels_first_layer:int=10, # The number of output channels after the first encoding step.
        normalization:str='batch', # The type of normalization to use. Possible options include "batch", "layer" and "instance".
        pooling_type:str='max', # The type of pooling to use.
        upsampling_type:str='nearest', # The type of upsampling to use.
        preactivation:bool=False, # Whether to use preactivations.
        residual:bool=False, # Whether the encoder should be a residual network.
        use_padding:bool=True, # Whether to use padding.
        padding_mode:str='zeros', # The type of padding to use.
        activation:str='ReLU', # The activation function that should be used.
        classifier_activation:str='Sigmoid', # The activation function for the classifier at the end of the UNet.
        initial_dilation:Union[int,None]=None, # The amount of dilation that should be used in the first encoding block.
        dropout:float=0., # The dropout rate.
        upsample_recover_orig_size:bool=True, # Whether the original input size of the encoder should be recovered with the decoder output.
        pooling_kernel_size:Union[int,list]=[3, 3, 3], # The size of the pooling kernel. An `int` is used for every dimension; a sequence must have one entry per dimension.
        use_classifier:bool=True, # Whether to use a classifier layer at the end of the network.
        verbose:bool=True # Whether to print the user information on the neural network, for instance the number of parameters.
    ):
        super().__init__()

        # Validate/normalize before anything is built, so bad input fails early
        # and the shared default list object is never stored on an instance.
        pooling_kernel_size = self._normalize_pooling_kernel_size(pooling_kernel_size, dimensions)

        self.in_channels = in_channels
        self.out_classes = out_classes
        self.dimensions = dimensions
        self.num_encoding_blocks = num_encoding_blocks
        self.out_channels_first_layer = out_channels_first_layer
        self.normalization = normalization
        self.pooling_type = pooling_type
        self.upsampling_type = upsampling_type
        self.preactivation = preactivation
        self.residual = residual
        self.use_padding = use_padding
        self.padding_mode = padding_mode
        self.activation = activation
        self.classifier_activation = classifier_activation
        self.initial_dilation = initial_dilation
        self.dropout = dropout
        self.upsample_recover_orig_size = upsample_recover_orig_size
        self.pooling_kernel_size = pooling_kernel_size
        self.use_classifier = use_classifier
        self.verbose = verbose

        # The bottom block is built separately, so the encoder and the decoder
        # each get one block fewer than `num_encoding_blocks`.
        self.depth = self.num_encoding_blocks - 1

        # Residual networks always run with padding, regardless of `use_padding`.
        if residual:
            self.use_padding = True

        # Order matters: the bottom block and the decoder read attributes of the encoder.
        self.encoder = self._get_encoder()
        self.bottom_block = self._get_bottom_block()
        self.decoder = self._get_decoder()

        if self.use_classifier:
            self.classifier = self._get_classifier()

        if verbose:
            n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            print(f'Built model with {n_params} parameters.')


    @staticmethod
    def _normalize_pooling_kernel_size(pooling_kernel_size, dimensions):
        """
        Returns a per-dimension pooling kernel size.

        An `int` is expanded to a list with `dimensions` identical entries.
        A sequence must already have exactly `dimensions` entries, otherwise a `ValueError` is raised.
        Lists are copied so that the caller's list (or the default argument) is not shared with the model;
        other sequence types are returned unchanged.
        """
        if isinstance(pooling_kernel_size, int):
            return [pooling_kernel_size] * dimensions

        if len(pooling_kernel_size) != dimensions:
            raise ValueError('Length of pooling_kernel_size and dimension do not coincide!')

        if isinstance(pooling_kernel_size, list):
            return list(pooling_kernel_size)
        return pooling_kernel_size


    def _get_encoder(self):
        """
        Builds the encoder with `self.depth` encoding blocks from the stored hyperparameters.
        """
        encoder = Encoder(
            in_channels=self.in_channels,
            out_channels_first=self.out_channels_first_layer,
            dimensions=self.dimensions,
            pooling_type=self.pooling_type,
            num_encoding_blocks=self.depth,
            normalization=self.normalization,
            preactivation=self.preactivation,
            residual=self.residual,
            use_padding=self.use_padding,
            padding_mode=self.padding_mode,
            activation=self.activation,
            initial_dilation=self.initial_dilation,
            dropout=self.dropout,
            pooling_kernel_size=self.pooling_kernel_size
        )
        return encoder


    def _get_bottom_block(self):
        """
        Builds the bottom block, an encoding block without pooling that follows the encoder.
        Requires `self.encoder` to exist already.
        """
        out_channels_first = self.encoder.out_channels

        # In 2d the first output channel count of the bottom block is doubled.
        # NOTE(review): presumably this mirrors the channel progression inside `EncodingBlock` - confirm there.
        if self.dimensions == 2:
            out_channels_first = 2 * out_channels_first

        bottom_block = EncodingBlock(
            in_channels=self.encoder.out_channels,
            out_channels_first=out_channels_first,
            dimensions=self.dimensions,
            normalization=self.normalization,
            pooling_type=None,
            preactivation=self.preactivation,
            residual=self.residual,
            use_padding=self.use_padding,
            padding_mode=self.padding_mode,
            activation=self.activation,
            dilation=self.encoder.dilation,
            dropout=self.dropout,
            pooling_kernel_size=self.pooling_kernel_size
        )
        return bottom_block


    def _get_decoder(self):
        """
        Builds the decoder with `self.depth` decoding blocks.
        Requires `self.encoder` to exist already, since its dilation is passed on to the decoder.
        """
        # The skip-connection width is `out_channels_first_layer * 2**power`,
        # with one doubling less in 2d than in 3d.
        power = self.depth

        if self.dimensions == 2:
            power = power - 1

        decoder = Decoder(
            in_channels_skip_connection=self.out_channels_first_layer * 2**power,
            dimensions=self.dimensions,
            upsampling_type=self.upsampling_type,
            num_decoding_blocks=self.depth,
            normalization=self.normalization,
            preactivation=self.preactivation,
            residual=self.residual,
            use_padding=self.use_padding,
            padding_mode=self.padding_mode,
            activation=self.activation,
            initial_dilation=self.encoder.dilation,
            dropout=self.dropout,
            upsample_recover_orig_size=self.upsample_recover_orig_size
        )
        return decoder


    def _get_classifier(self):
        """
        Builds the final classifier, a convolutional block with kernel size 1 that maps to `out_classes` channels.
        Requires `self.bottom_block` to exist already.
        """
        # Fallback that is only kept if `dimensions` is neither 2 nor 3.
        in_channels = self.bottom_block.out_channels

        if self.dimensions == 2:
            in_channels = self.out_channels_first_layer
        elif self.dimensions == 3:
            in_channels = 2 * self.out_channels_first_layer

        classifier = ConvolutionalBlock(
            dimensions=self.dimensions,
            in_channels=in_channels,
            out_channels=self.out_classes,
            kernel_size=1,
            activation=self.classifier_activation
        )
        return classifier


    def forward(self, 
                model_inputs:torch.Tensor # The input to the UNet.
               ):
        """
        The forward pass of the UNet. Returns a `torch.Tensor` object.
        """
        # The encoder returns the skip connections together with its final encoding.
        skip_connections, encoding = self.encoder(model_inputs)
        encoding = self.bottom_block(encoding)

        output = self.decoder(skip_connections, encoding)

        # Without a classifier the raw decoder output is returned.
        if self.use_classifier:
            return self.classifier(output)
        return output

In [None]:
# Render the documentation of `UNet.forward` (signature, docstring and parameter table).
show_doc(UNet.forward)

<h4 id="UNet.forward" class="doc_header"><code>UNet.forward</code><a href="__main__.py#L162" class="source_link" style="float:right">[source]</a></h4>

> <code>UNet.forward</code>(**`model_inputs`**:`Tensor`)

The forward pass of the UNet. Returns a `torch.Tensor` object.

||Type|Default|Details|
|---|---|---|---|
|**`model_inputs`**|`Tensor`||The input to the UNet.|


In [None]:
#export
class UNet3D(UNet):
    """
    A `UNet` whose `dimensions` argument is preset to 3.
    UNets are convolutional autoencoders, developed at the Computer Science Department of the University of Freiburg for biomedical image segmentation [1].
    They build on a fully convolutional network whose architecture was modified and extended so that it works with fewer training images and yields more precise segmentations.
    Our code is based on `https://github.com/fepegar/unet/tree/master/unet`.
    """
    def __init__(self, *args, **user_kwargs):
        # `dimensions=3` is only a preset: keyword arguments given by the caller
        # come later in the literal and therefore take precedence.
        merged_kwargs = {'dimensions': 3, **user_kwargs}
        super().__init__(*args, **merged_kwargs)

# References

[1] Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70.

In [None]:
#hide
import hypothesis.strategies as st
from hypothesis import given, settings

In [None]:
#hide
# Channel and class counts are both drawn from 1..50.
st_n_channels = st.integers(min_value=1, max_value=50)
st_n_output_classes = st.integers(min_value=1, max_value=50)

# Three independent spatial sizes, each drawn from 6..50.
st_input_shape = st.tuples(
    *[st.integers(min_value=6, max_value=50) for _ in range(3)]
)

In [None]:
%%time
#hide
@given(n_channels=st_n_channels, input_shape=st_input_shape, n_output_classes=st_n_output_classes)
@settings(max_examples=5, deadline=None)
def test_output_shapes_in_3d(n_channels, input_shape, n_output_classes):
    """
    The output of a `UNet3D` has the batch and spatial sizes of its input and `n_output_classes` channels.
    """
    unet = UNet3D(
        in_channels=n_channels,
        out_classes=n_output_classes,
        num_encoding_blocks=2,
        verbose=False
    )
    unet.eval()

    batch = torch.rand(1, n_channels, *input_shape)
    expected_shape = (1, n_output_classes, *input_shape)

    prediction = unet(batch)
    assert prediction.shape == expected_shape


test_output_shapes_in_3d()

CPU times: user 22.2 s, sys: 81 ms, total: 22.3 s
Wall time: 2.54 s


In [None]:
%%time
#hide
@given(input_shape=st_input_shape)
@settings(max_examples=5, deadline=None)
def test_default_output_shapes_in_3d(input_shape):
    """
    With its default channel settings (7 input channels, 1 output class) a `UNet3D` preserves the spatial input shape.
    """
    # Named differently from `test_output_shapes_in_3d` so that this test does not shadow it.
    model = UNet3D(
        num_encoding_blocks=2,
        verbose=False
    ).eval()

    x = torch.rand(1, 7, *input_shape)
    assert model(x).shape == (1, 1, *input_shape)


test_default_output_shapes_in_3d()

CPU times: user 22.1 s, sys: 65.8 ms, total: 22.1 s
Wall time: 2.43 s
