In [None]:
#default_exp models

In [None]:
#exporti
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from dl4to.models import ConvolutionalBlock

In [None]:
#hide
from nbdev.showdoc import show_doc

# Decoder

In [None]:
#export
class DecodingBlock(nn.Module):
    """
    This class defines a decoding block for the decoder.

    In the forward pass the input is upsampled, concatenated along the channel axis with the
    (center-cropped) skip connection and passed through two `ConvolutionalBlock`s. If `residual`
    is set, the concatenated tensor is additionally passed through a `ConvolutionalBlock` with
    `kernel_size=1` and the result is added to the output of the two blocks.
    """
    def __init__(
        self,
        in_channels_skip_connection:int, # The number of input channels from the skip connections of the encoder.
        dimensions:int, # The number of dimensions to consider. Possible options are 2 and 3.
        upsampling_type:str, # The type of upsampling to use. Either "conv" (transposed convolution) or one of "nearest", "linear", "bilinear", "bicubic" and "trilinear".
        normalization:str, # The type of normalization to use. Possible options include "batch", "layer" and "instance".
        preactivation:bool, # Whether to use preactivations.
        residual:bool=False, # Whether the decoder should be a residual network.
        use_padding:bool=False, # Whether to use padding.
        padding_mode:str='zeros', # The type of padding to use.
        activation:str='ReLU', # The activation function that should be used.
        dilation:int=None, # The amount of dilation that should be used.
        dropout:float=0., # The dropout rate.
        upsample_recover_orig_size:bool=False, # Whether the original input size of the encoder should be recovered with the decoder output.
    ):
        super().__init__()

        self.residual = residual
        self.upsampling_type = upsampling_type
        self.upsample_recover_orig_size = upsample_recover_orig_size

        if upsampling_type == 'conv':
            if upsample_recover_orig_size:
                # The flag only affects interpolation-based upsampling, so it has no effect here.
                print("Ignoring upsample_recover_orig_size=True when using upsampling_type=conv.")

            in_channels = out_channels = 2 * in_channels_skip_connection
            self.upsample = self._get_conv_transpose_layer(dimensions, in_channels, out_channels)
        else:
            self.upsample = self._get_upsampling_layer(upsampling_type)

        # The first convolution sees the skip connection concatenated with the upsampled input.
        # NOTE(review): 3x assumes the incoming `x` carries 2 * in_channels_skip_connection
        # channels (this is what the 'conv' branch above is built for) - confirm for other modes.
        in_channels_first = 3 * in_channels_skip_connection
        out_channels = in_channels_skip_connection

        self.conv1 = ConvolutionalBlock(
            dimensions=dimensions,
            in_channels=in_channels_first,
            out_channels=out_channels,
            normalization=normalization,
            preactivation=preactivation,
            use_padding=use_padding,
            padding_mode=padding_mode,
            activation=activation,
            dilation=dilation,
            dropout=dropout
        )

        in_channels_second = out_channels

        self.conv2 = ConvolutionalBlock(
            dimensions=dimensions,
            in_channels=in_channels_second,
            out_channels=out_channels,
            normalization=normalization,
            preactivation=preactivation,
            use_padding=use_padding,
            padding_mode=padding_mode,
            activation=activation,
            dilation=dilation,
            dropout=dropout
        )

        if residual:
            # Projects the concatenated tensor to `out_channels` so it can be added to the conv output.
            self.conv_residual = ConvolutionalBlock(
                dimensions=dimensions,
                in_channels=in_channels_first,
                out_channels=out_channels,
                kernel_size=1,
                normalization=None,
                activation=None
            )



    def forward(self, 
                skip_connection:torch.Tensor, # The output of the skip connection from the corresponding encoding block.
                x:torch.Tensor # The input to the decoding block.
               ):
        """
        The forward pass of the decoding block.
        """
        if self.upsampling_type != 'conv':
            # `nn.Upsample` accepts either a target size or a scale factor, never both,
            # so the unused one is cleared explicitly.
            if self.upsample_recover_orig_size:
                # Spatial dimensions start at index 2 (after batch and channel); this works for 2D and 3D.
                self.upsample.size = tuple(skip_connection.shape[2:])
                self.upsample.scale_factor = None
            else:
                self.upsample.size = None
                self.upsample.scale_factor = 2.

        x = self.upsample(x)
        skip_connection = self._center_crop(skip_connection, x)
        x = torch.cat((skip_connection, x), dim=1)

        if self.residual:
            connection = self.conv_residual(x)
            x = self.conv1(x)
            x = self.conv2(x)
            # Out-of-place addition: modifying the conv output in place can invalidate
            # tensors that autograd saved for the backward pass.
            x = x + connection
        else:
            x = self.conv1(x)
            x = self.conv2(x)

        return x


    def _get_upsampling_layer(self, upsampling_type):
        """
        Returns an `nn.Upsample` layer for the given interpolation mode.
        Raises a `ValueError` if the mode is not supported.
        """
        upsampling_modes = ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear')

        if upsampling_type not in upsampling_modes:
            message = (f'Upsampling type is {upsampling_type} but should be one of the following: {upsampling_modes}.')
            raise ValueError(message)

        # Size / scale factor are set in `forward`, since the target size depends on the skip connection.
        upsample = nn.Upsample(mode=upsampling_type)
        return upsample


    def _get_conv_transpose_layer(self, dimensions, in_channels, out_channels):
        """
        Returns a transposed convolution (`kernel_size=2`, `stride=2`) for the given number of dimensions.
        """
        conv_class = getattr(nn, f'ConvTranspose{dimensions}d')
        conv_layer = conv_class(in_channels, out_channels, kernel_size=2, stride=2)
        return conv_layer


    def _center_crop(self, skip_connection, x):
        """
        Crops `skip_connection` around its center so that its spatial size matches the one of `x`.
        If `skip_connection` is smaller than `x` along a dimension, it is zero-padded instead.
        """
        pad = []

        # `F.pad` expects (before, after) pairs starting with the LAST dimension,
        # hence the spatial sizes are traversed in reverse order.
        spatial_sizes = zip(reversed(skip_connection.shape[2:]), reversed(x.shape[2:]))

        for skip_size, x_size in spatial_sizes:
            crop = skip_size - x_size
            crop_before = crop // 2
            # The remainder goes to the far side so that odd differences are removed completely.
            crop_after = crop - crop_before
            # Negative padding removes elements.
            pad.extend((-crop_before, -crop_after))

        skip_connection = F.pad(skip_connection, pad)
        return skip_connection

In [None]:
show_doc(DecodingBlock.forward)

<h4 id="DecodingBlock.forward" class="doc_header"><code>DecodingBlock.forward</code><a href="__main__.py#L79" class="source_link" style="float:right">[source]</a></h4>

> <code>DecodingBlock.forward</code>(**`skip_connection`**:`list`, **`x`**:`Tensor`)

The forward pass of the decoding block.

||Type|Default|Details|
|---|---|---|---|
|**`skip_connection`**|`list`||A list of `torch.Tensors` that contain the outputs of the skip connections from an encoding block.|
|**`x`**|`Tensor`||The input to the decoding block.|


In [None]:
#export
class Decoder(nn.Module):
    """
    Defines a decoder that can be used for the construction of UNets [1].
    It consists of a sequence of `DecodingBlock`s which are applied one after another,
    each receiving one of the encoder's skip connections together with the current features.
    """
    def __init__(
        self,
        in_channels_skip_connection:int, # The number of channels of the skip connection fed to the first decoding block. It is halved for every subsequent block.
        dimensions:int, # The number of dimensions to consider. Possible options are 2 and 3.
        upsampling_type:str, # The type of upsampling to use. "linear" is translated to "bilinear" (2D) or "trilinear" (3D).
        num_decoding_blocks:int, # The number of decoding blocks.
        normalization:str, # The type of normalization to use. Possible options include "batch", "layer" and "instance".
        preactivation:bool, # Whether to use preactivations.
        residual:bool=False, # Whether the decoder should be a residual network.
        use_padding:bool=False, # Whether to use padding.
        padding_mode:str='zeros', # The type of padding to use.
        activation:str='ReLU', # The activation function that should be used.
        initial_dilation:int=None, # The amount of dilation that should be used in the first decoding block. It is halved for every subsequent block.
        dropout:float=0., # The dropout rate.
        upsample_recover_orig_size:bool=False, # Whether the original input size of the encoder should be recovered with the decoder output.
    ):
        super().__init__()
        upsampling_type = self._fix_upsampling_type(upsampling_type, dimensions)

        # Settings that are identical for all decoding blocks.
        shared_settings = dict(
            dimensions=dimensions,
            upsampling_type=upsampling_type,
            normalization=normalization,
            preactivation=preactivation,
            residual=residual,
            use_padding=use_padding,
            padding_mode=padding_mode,
            activation=activation,
            dropout=dropout,
            upsample_recover_orig_size=upsample_recover_orig_size,
        )

        self.decoding_blocks = nn.ModuleList()
        # Note: `self.dilation` is updated while the blocks are created, so after
        # construction it holds the value that follows the last block's dilation.
        self.dilation = initial_dilation

        skip_channels = in_channels_skip_connection

        for _ in range(num_decoding_blocks):
            self.decoding_blocks.append(
                DecodingBlock(
                    in_channels_skip_connection=skip_channels,
                    dilation=self.dilation,
                    **shared_settings
                )
            )

            # Channels and dilation are both halved from one block to the next.
            skip_channels = skip_channels // 2

            if self.dilation is not None:
                self.dilation = self.dilation // 2


    def _fix_upsampling_type(self, upsampling_type, dimensions):
        """
        Maps the generic mode "linear" to the dimension-specific interpolation mode.
        All other modes (and "linear" for dimensions other than 2 and 3) are returned unchanged.
        """
        if upsampling_type != 'linear':
            return upsampling_type

        if dimensions == 2:
            return 'bilinear'

        if dimensions == 3:
            return 'trilinear'

        return upsampling_type


    def forward(self, 
                skip_connections:list, # A list of `torch.Tensors` that contain the outputs of the skip connections from an encoder.
                x:torch.Tensor # The input to the decoder.
               ):
        """
        The forward pass of the decoder.
        """
        # The skip connections are consumed in reverse order: the last one goes to the first block.
        for block, skip in zip(self.decoding_blocks, reversed(skip_connections)):
            x = block(skip, x)

        return x

In [None]:
show_doc(Decoder.forward)

<h4 id="Decoder.forward" class="doc_header"><code>Decoder.forward</code><a href="__main__.py#L62" class="source_link" style="float:right">[source]</a></h4>

> <code>Decoder.forward</code>(**`skip_connections`**:`list`, **`x`**:`Tensor`)

The forward pass of the decoder.

||Type|Default|Details|
|---|---|---|---|
|**`skip_connections`**|`list`||A list of `torch.Tensors` that contain the outputs of the skip connections from an encoder.|
|**`x`**|`Tensor`||The input to the decoder.|


# References

[1] Falk, Thorsten, et al. "U-Net: deep learning for cell counting, detection, and morphometry." Nature methods 16.1 (2019): 67-70.