In [None]:
from nbdev import *
%nbdev_default_export models

Cells will be exported to deepflash2.models,
unless a different module is specified after an export flag: `%nbdev_export special.module`


# Models

> U-Net building blocks and the `UNet2D` model.

In [None]:
%nbdev_hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
%nbdev_export
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
%nbdev_export
class UNetConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages used at every level of the U-Net.

    Each stage is ``Conv2d -> [BatchNorm2d] -> LeakyReLU``. An optional
    ``Dropout`` layer is placed in front of the first convolution.

    Args:
        in_size (int): number of input channels of the first convolution.
        out_size (int): number of output channels of both convolutions.
        padding (bool or int): converted with ``int()`` and used as the
            padding of both convolutions (``True`` -> 1, ``False`` -> 0).
        batch_norm (bool): if truthy, add ``BatchNorm2d`` after each
            convolution.
        dropout (float): dropout probability; the layer is only created
            when the value is greater than 0.
        neg_slope (float): negative slope of both ``LeakyReLU`` activations.
    """

    def __init__(self, in_size, out_size, padding, batch_norm,
                 dropout=0., neg_slope=0.1):
        super().__init__()

        # Optional dropout is always the first layer of the block.
        layers = [nn.Dropout(p=dropout)] if dropout > 0. else []
        pad = int(padding)

        # Both stages are identical except for the input channel count.
        for stage_in in (in_size, out_size):
            layers.append(nn.Conv2d(stage_in, out_size, kernel_size=3, padding=pad))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(negative_slope=neg_slope))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential block and return the result."""
        return self.block(x)

In [None]:
%nbdev_export
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, convolve.

    The incoming feature map is upsampled by a factor of 2 and reduced to
    ``out_size`` channels, concatenated along the channel axis with a
    centre crop of the skip connection (``bridge``), and passed through a
    ``UNetConvBlock`` that expects ``in_size`` input channels.

    Args:
        in_size (int): channels of the incoming feature map; also the channel
            count the convolution block expects after concatenation.
        out_size (int): channels produced by the upsampling step and by the
            convolution block.
        up_mode (str): ``'upconv'`` uses a stride-2 transposed convolution,
            ``'upsample'`` uses bilinear upsampling followed by a 1x1
            convolution.
        padding (bool or int): forwarded to the convolution block.
        batch_norm (bool): if truthy, add ``BatchNorm2d`` after the
            upsampling step and inside the convolution block.
        dropout (float): if greater than 0, apply dropout to the input of the
            upsampling step. It is not forwarded to the convolution block.
        neg_slope (float): negative slope of the ``LeakyReLU`` activations,
            used after the upsampling step and inside the convolution block.

    Raises:
        ValueError: if ``up_mode`` is neither ``'upconv'`` nor ``'upsample'``.
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm,
                 dropout=0., neg_slope=0.1):
        super().__init__()
        # Fail early: an unknown mode would otherwise build a block that does
        # not upsample at all.
        if up_mode not in ('upconv', 'upsample'):
            raise ValueError(
                "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode)
            )

        up_block = []
        if dropout > 0.:
            up_block.append(nn.Dropout(p=dropout))
        if up_mode == 'upconv':
            up_block.append(nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2))
        else:  # 'upsample'
            up_block.append(nn.Upsample(mode='bilinear', scale_factor=2))
            up_block.append(nn.Conv2d(in_size, out_size, kernel_size=1))
        if batch_norm:
            up_block.append(nn.BatchNorm2d(out_size))
        up_block.append(nn.LeakyReLU(negative_slope=neg_slope))

        self.up = nn.Sequential(*up_block)
        # Pass neg_slope through so the whole stage uses the same activation
        # slope (previously the convolution block always used its default).
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm,
                                        neg_slope=neg_slope)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` (height, width) window of ``layer``.

        ``layer`` is indexed as a 4-D tensor; the first two dimensions are
        kept unchanged.
        """
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        """Upsample ``x``, merge it with the cropped ``bridge`` and convolve."""
        up = self.up(x)
        # The skip connection may be spatially larger than the upsampled map
        # (unpadded convolutions), so crop it to the same height and width.
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

## Unet class

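Dropout placement follows the "Dropout Central Enc-Dec" variant from "Bayesian SegNet: Model Uncertainty in Deep Convolutional Encoder-Decoder Architectures for Scene Understanding" (https://arxiv.org/pdf/1511.02680.pdf):
 - Dropout is applied only in the central (deepest) encoder and decoder blocks, i.e. the levels with index greater than 2 in `UNet2D`.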

In [None]:
%nbdev_export
class UNet2D(nn.Module):
    """2D U-Net built from ``UNetConvBlock`` and ``UNetUpBlock`` stages."""

    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        dropout=0.,
        up_mode='upconv',
    ):
        """
        U-Net style encoder-decoder based on
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): number of levels in the contracting path; the
                         expanding path has ``depth - 1`` levels
            wf (int): number of filters at level ``i`` is 2**(wf + i)
            padding (bool): forwarded to the convolution blocks; if True,
                            the 3x3 convolutions are padded by 1.
                            This may introduce artifacts
            batch_norm (bool): accepted but not read by this constructor;
                               batch normalisation is enabled on every level
                               except the first (index 0)
            dropout (float): dropout probability used on levels with index
                             greater than 2; other levels use no dropout
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super().__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        # Contracting path: one convolution block per level.
        channels = in_channels
        self.down_path = nn.ModuleList()
        for level in range(depth):
            width = 2 ** (wf + level)
            self.down_path.append(
                UNetConvBlock(
                    channels, width, padding,
                    batch_norm=level > 0,
                    dropout=dropout if level > 2 else 0.,
                )
            )
            channels = width

        # Expanding path: mirrors the contracting path from the second-deepest
        # level back up to level 0.
        self.up_path = nn.ModuleList()
        for level in reversed(range(depth - 1)):
            width = 2 ** (wf + level)
            self.up_path.append(
                UNetUpBlock(
                    channels, width, up_mode, padding,
                    batch_norm=level > 0,
                    dropout=dropout if level > 2 else 0.,
                )
            )
            channels = width

        # 1x1 convolution mapping the final feature map to the output channels.
        # Note: _initialize_weights() is not called here.
        self.last = nn.Conv2d(channels, n_classes, kernel_size=1)

    def _initialize_weights(self):
        """Apply Kaiming-normal init to all (transposed) convolution weights."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Run the contracting path, then the expanding path with skip connections."""
        skips = []
        bottom = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            if level == bottom:
                # The deepest block is neither stored as a skip nor pooled.
                continue
            skips.append(x)
            x = F.max_pool2d(x, 2)

        # Skip connections are consumed deepest-first.
        for up in self.up_path:
            x = up(x, skips.pop())

        return self.last(x)

# Export 

In [None]:
%nbdev_hide
from nbdev.export import *
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted index.ipynb.
