In [None]:
#default_exp models

# Encoder

In [None]:
#exporti
import torch
import torch.nn as nn

from dl4to.models import ConvolutionalBlock

In [None]:
#hide
from nbdev.showdoc import show_doc

In [None]:
#export
class EncodingBlock(nn.Module):
    """
    This class defines a single encoding block for an encoder.
    It applies two `ConvolutionalBlock`s (plus an optional 1x1 residual branch) and then,
    if a `pooling_type` is given, downsamples the result with a pooling layer.
    """
    def __init__(
        self,
        in_channels:int, # The number of input channels.
        out_channels_first:int, # The number of output channels after the first encoding step.
        dimensions:int, # The number of dimensions to consider. Possible options are 2 and 3.
        normalization:str, # The type of normalization to use. Possible options include "batch", "layer" and "instance".
        pooling_type:str, # The type of pooling to use, e.g. "max" for `nn.MaxPool{dimensions}d`. `None` disables downsampling.
        pooling_kernel_size:int, # The size of the pooling kernel (an int, or one entry per dimension).
        preactivation:bool=False, # Whether to use preactivations.
        is_first_block:bool=False, # Whether this is the first block of an encoder.
        residual:bool=False, # Whether the encoder should be a residual network.
        use_padding:bool=False, # Whether to use padding.
        padding_mode:str='zeros', # The type of padding to use (applied to both convolutions).
        activation:str='ReLU', # The activation function that should be used.
        dilation:int=None, # The amount of dilation that should be used.
        dropout:float=0., # The dropout rate.
    ):
        super().__init__()

        if dimensions not in (2, 3):
            # Without this check `out_channels_second` below would be left unbound.
            raise ValueError(f'dimensions must be 2 or 3, but got {dimensions}.')

        self.preactivation = preactivation
        self.normalization = normalization
        self.residual = residual

        # The first block of an encoder builds its first convolution without
        # normalization and preactivation; the second convolution always uses them.
        first_normalization = None if is_first_block else normalization
        first_preactivation = None if is_first_block else preactivation

        self.conv1 = ConvolutionalBlock(
            dimensions=dimensions,
            in_channels=in_channels,
            out_channels=out_channels_first,
            normalization=first_normalization,
            preactivation=first_preactivation,
            use_padding=use_padding,
            padding_mode=padding_mode,
            activation=activation,
            dilation=dilation,
            dropout=dropout
        )

        # In 2D the second convolution keeps the channel count, in 3D it doubles it.
        out_channels_second = out_channels_first if dimensions == 2 else 2 * out_channels_first

        self.conv2 = ConvolutionalBlock(
            dimensions=dimensions,
            in_channels=out_channels_first,
            out_channels=out_channels_second,
            normalization=self.normalization,
            preactivation=self.preactivation,
            use_padding=use_padding,
            padding_mode=padding_mode,  # forwarded like in conv1 so both convolutions pad the same way
            activation=activation,
            dilation=dilation,
            dropout=dropout
        )

        if residual:
            # 1x1 projection so the block input can be added to the conv2 output.
            # NOTE(review): this branch presumably keeps the spatial size, so the residual
            # sum likely only has matching shapes when use_padding=True — confirm.
            self.conv_residual = ConvolutionalBlock(
                dimensions=dimensions,
                in_channels=in_channels,
                out_channels=out_channels_second,
                kernel_size=1,
                normalization=None,
                activation=None
            )

        self._set_downsampling_layer(dimensions, pooling_type, kernel_size=pooling_kernel_size)


    def forward(self, 
                x:torch.Tensor # the input to the encoding block.
               ):
        """
        The forward pass of the encoding block.
        If a pooling layer is configured, returns a tuple `(x, skip_connection)`: the downsampled output
        `torch.Tensor` and the `torch.Tensor` before pooling, which serves as skip connection.
        Without a pooling layer, only the output `torch.Tensor` is returned.
        """
        connection = self.conv_residual(x) if self.residual else None
        x = self.conv1(x)
        x = self.conv2(x)
        if connection is not None:
            # Out-of-place addition: an in-place `+=` would modify the conv2 output, which
            # autograd may have saved for the backward pass (e.g. by a ReLU activation).
            x = x + connection

        if self.downsample is None:
            return x

        skip_connection = x
        x = self.downsample(x)
        return x, skip_connection


    @property
    def out_channels(self):
        "The number of channels produced by the block (those of the second convolution)."
        return self.conv2.conv_layer.out_channels


    def _set_downsampling_layer(self, dimensions, pooling_type, kernel_size, stride=2):
        """
        Sets `self.downsample` to `nn.<Type>Pool<dimensions>d(kernel_size, stride)`, e.g. `nn.MaxPool3d`
        for `pooling_type='max'`, or to `None` when `pooling_type` is `None`.
        """
        if pooling_type is None:
            self.downsample = None
        else:
            class_name = '{}Pool{}d'.format(pooling_type.capitalize(), dimensions)
            class_ = getattr(nn, class_name)
            self.downsample = class_(kernel_size, stride)

In [None]:
show_doc(EncodingBlock.forward)  # Render the signature, docstring and argument table of EncodingBlock.forward.

<h4 id="EncodingBlock.forward" class="doc_header"><code>EncodingBlock.forward</code><a href="__main__.py#L80" class="source_link" style="float:right">[source]</a></h4>

> <code>EncodingBlock.forward</code>(**`x`**:`Tensor`)

The forward pass of the encoding block.
Returns the downsampled output `torch.Tensor` together with the pre-pooling `torch.Tensor` used as skip connection; without a pooling layer, only the output `torch.Tensor` is returned.

||Type|Default|Details|
|---|---|---|---|
|**`x`**|`Tensor`||the input to the encoding block.|


In [None]:
#export
class Encoder(nn.Module):
    """
    This class defines an encoder that can be used for the construction of UNets [1].
    It stacks `num_encoding_blocks` encoding blocks; every block downsamples its input, and the
    pre-pooling output of each block is kept as a skip connection for a decoder.
    """
    def __init__(
        self,
        in_channels:int, # The number of input channels.
        out_channels_first:int, # The number of output channels after the first encoding step.
        dimensions:int, # The number of dimensions to consider. Possible options are 2 and 3.
        pooling_type:str, # The type of pooling to use, e.g. "max". Must not be `None`, since every block has to produce a skip connection.
        num_encoding_blocks:int, # The number of encoding blocks.
        normalization:str, # The type of normalization to use. Possible options include "batch", "layer" and "instance".
        pooling_kernel_size:int, # The size of the pooling kernel: an int, or a sequence with one entry per dimension.
        preactivation:bool=False, # Whether to use preactivations.
        residual:bool=False, # Whether the encoder should be a residual network.
        use_padding:bool=False, # Whether to use padding.
        padding_mode:str='zeros', # The type of padding to use.
        activation:str='ReLU', # The activation function that should be used.
        initial_dilation:int=None, # The amount of dilation that should be used in the first encoding block. It is doubled for every further block.
        dropout:float=0. # The dropout rate.
    ):
        super().__init__()

        if pooling_type is None:
            # `forward` unpacks `(x, skip_connection)` from every block, which an
            # encoding block only returns when it has a pooling layer.
            raise ValueError('pooling_type must not be None for an Encoder.')
        # An int kernel size is valid for every dimension; sequences need one entry per dimension.
        if not isinstance(pooling_kernel_size, int) and len(pooling_kernel_size) != dimensions:
            raise ValueError('Length of pooling_kernel_size and dimension do not coincide!')

        self.encoding_blocks = nn.ModuleList()
        # Dilation of the next block to be built (doubled after each block).
        self.dilation = initial_dilation

        for block_index in range(num_encoding_blocks):
            encoding_block = EncodingBlock(
                in_channels=in_channels,
                out_channels_first=out_channels_first,
                dimensions=dimensions,
                normalization=normalization,
                pooling_type=pooling_type,
                preactivation=preactivation,
                is_first_block=block_index == 0,
                residual=residual,
                use_padding=use_padding,
                padding_mode=padding_mode,
                activation=activation,
                dilation=self.dilation,
                dropout=dropout,
                pooling_kernel_size=pooling_kernel_size
            )
            self.encoding_blocks.append(encoding_block)

            # Prepare the channel counts of the next block so that its input channels
            # equal this block's output channels (out_channels_first in 2D, twice that in 3D).
            if dimensions == 2:
                in_channels = out_channels_first
                out_channels_first = in_channels * 2
            elif dimensions == 3:
                in_channels = 2 * out_channels_first
                out_channels_first = in_channels

            if self.dilation is not None:
                self.dilation *= 2


    @property
    def out_channels(self):
        "The number of channels produced by the last encoding block."
        return self.encoding_blocks[-1].out_channels


    def forward(self, 
                x:torch.Tensor # The input of the encoder.
               ):
        """
        The forward pass of the encoder. 
        Returns a list of `torch.Tensors` that define the outputs of the skip connections, and a `torch.Tensor` that is the output of the encoder.
        """
        skip_connections = []

        for encoding_block in self.encoding_blocks:
            x, skip_connection = encoding_block(x)
            skip_connections.append(skip_connection)

        return skip_connections, x

In [None]:
show_doc(Encoder.forward)  # Render the signature, docstring and argument table of Encoder.forward.

<h4 id="Encoder.forward" class="doc_header"><code>Encoder.forward</code><a href="__main__.py#L70" class="source_link" style="float:right">[source]</a></h4>

> <code>Encoder.forward</code>(**`x`**:`Tensor`)

The forward pass of the encoder. 
Returns a list of `torch.Tensor`s that define the outputs of the skip connections, and a `torch.Tensor` that is the output of the encoder.

||Type|Default|Details|
|---|---|---|---|
|**`x`**|`Tensor`||The input of the encoder.|


# References

[1] Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature Methods 16.1 (2019): 67-70.

In [None]:
#hide
import torch
import hypothesis.strategies as st
from hypothesis import given, settings

In [None]:
#hide
# Number of input channels fed to the encoder.
st_n_channels = st.integers(min_value=1, max_value=50)

# Three independently drawn spatial sizes, each in [32, 64].
st_input_shape = st.tuples(*(st.integers(min_value=32, max_value=64) for _ in range(3)))

# Encoder width and depth.
st_out_channels_first = st.integers(min_value=1, max_value=2)
st_num_encoding_blocks = st.integers(min_value=1, max_value=4)

In [None]:
%%time
#hide

from hypothesis import assume

@given(
    n_channels=st_n_channels,
    input_shape=st_input_shape,
    out_channels_first=st_out_channels_first,
    num_encoding_blocks=st_num_encoding_blocks
)
@settings(max_examples=2, deadline=None)
def test_output_shapes_in_3d_without_padding(
    n_channels,
    input_shape,
    out_channels_first,
    num_encoding_blocks
):
    """
    Without padding, each encoding block removes 4 voxels per axis (two unpadded convolutions),
    then halves the size with max pooling, while the channel count doubles per block.
    """
    output_shape = list(input_shape)
    for _ in range(num_encoding_blocks):
        output_shape = [(size - 4) // 2 for size in output_shape]

    # Deep encoders need large inputs: with 4 blocks every axis would need at least 76 voxels,
    # more than `st_input_shape` ever draws. Such configurations cannot run at all, so discard
    # them instead of letting the test pass or fail depending on which examples are drawn.
    assume(min(output_shape) >= 1)

    model = Encoder(
        in_channels=n_channels,
        out_channels_first=out_channels_first,
        dimensions=3,
        pooling_type='max',
        num_encoding_blocks=num_encoding_blocks,
        normalization=None,
        pooling_kernel_size=[2, 2, 2],
        use_padding=False,
    ).eval()

    x = torch.rand(1, n_channels, *input_shape)
    with torch.no_grad():
        encoded = model(x)[1]

    assert encoded.shape == (1, out_channels_first * 2**num_encoding_blocks, *output_shape)


test_output_shapes_in_3d_without_padding()

CPU times: user 2.48 s, sys: 31.1 ms, total: 2.51 s
Wall time: 354 ms


In [None]:
%%time
#hide

@given(
    n_channels=st_n_channels,
    input_shape=st_input_shape,
    out_channels_first=st_out_channels_first,
    num_encoding_blocks=st_num_encoding_blocks
)
@settings(max_examples=2, deadline=None)
def test_output_shapes_with_padding(
    n_channels,
    input_shape,
    out_channels_first,
    num_encoding_blocks
):
    """
    With padding, only the pooling layers change the spatial size: every axis is divided by
    2**num_encoding_blocks (rounded down), while the channel count grows by the same factor.
    """
    encoder = Encoder(
        in_channels=n_channels,
        out_channels_first=out_channels_first,
        dimensions=3,
        pooling_type='max',
        num_encoding_blocks=num_encoding_blocks,
        normalization=None,
        pooling_kernel_size=[2, 2, 2],
        use_padding=True,
    ).eval()

    sample = torch.rand(1, n_channels, *input_shape)

    factor = 2**num_encoding_blocks
    expected_spatial = [size // factor for size in input_shape]
    expected_shape = (1, out_channels_first * factor, *expected_spatial)

    _, encoded = encoder(sample)
    assert encoded.shape == expected_shape


test_output_shapes_with_padding()

CPU times: user 2.42 s, sys: 40.3 ms, total: 2.46 s
Wall time: 278 ms
