# U-Net

Implementation of the fastai [Dynamic U-Net](https://docs.fast.ai/vision.models.unet.html) for three-dimensional inputs.

In [1]:
# hide
import sys
# Make packages in the parent directory (e.g. `attention_unet`) importable from this notebook.
sys.path.extend(['..'])

In [2]:
# default_exp unet
# export
import torch
from torch import nn
import torch.nn.functional as F
from attention_unet.utils import create_body, in_channels, model_sizes, _get_sz_change_idxs, dummy_eval
from attention_unet.fastai_hooks import hook_outputs

In [3]:
from torchvision.models.video import r3d_18

In [4]:
# Build an r3d_18 encoder body (3 input channels, no pretrained weights) for interactive inspection.
body_3d = create_body(r3d_18,
                      n_in=3,
                      pretrained=False)

In [5]:
# export
class ConvLayer(nn.Sequential):
    """Create a sequence of convolutional layer, normalization layer and activation function.

    Layer order: [blur] -> conv -> [norm] -> [act] for regular convolutions, and
    conv -> [norm] -> [act] -> [blur] for transposed convolutions.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels; falls back to `in_channels` when falsy.
        kernel_size: int or per-dimension tuple.
        stride: int or per-dimension tuple.
        padding: int or tuple; if None, `(k-1)//2` per dimension for regular convolutions
            and 0 for transposed convolutions.
        blur: add replication padding + 2x2x2 average pooling (stride 1).
        act_cls: activation class instantiated without arguments; None disables it.
        norm_type: prefix of a `torch.nn` 3D norm layer (e.g. 'Batch' -> nn.BatchNorm3d); None disables it.
        transpose: use `nn.ConvTranspose3d` instead of `nn.Conv3d`.
        **kwargs: passed on to the convolution constructor.
    """
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
                 blur=False, act_cls=nn.ReLU, norm_type=None, transpose=False, **kwargs):
        # Layers are collected in a plain list and registered with a single
        # `nn.Sequential.__init__` call at the end; no attributes are set before it.
        if not out_channels: out_channels = in_channels
        if padding is None:
            if transpose:
                padding = 0
            elif isinstance(kernel_size, (tuple, list)):
                # per-dimension "same"-style padding for tuple kernels such as (1, 3, 3)
                padding = tuple((k - 1) // 2 for k in kernel_size)
            else:
                padding = (kernel_size - 1) // 2
        conv = nn.ConvTranspose3d if transpose else nn.Conv3d
        layers = [conv(in_channels, out_channels, kernel_size, stride, padding, **kwargs)]
        if norm_type: layers += [getattr(nn, f'{norm_type}Norm3d')(out_channels)]
        if act_cls: layers += [act_cls()]
        if blur:
            # pad one voxel at the start of each spatial dim so the stride-1 pool keeps the size
            blur_layers = [nn.ReplicationPad3d((1,0,1,0,1,0)), nn.AvgPool3d(2, stride=1)]
            if transpose: layers += blur_layers
            else: layers = [*blur_layers, *layers]
        super().__init__(*layers)

Create a sequence of a 3D convolutional (or transposed convolutional) layer, an optional normalization layer and an optional activation function.

**Args**

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| in_channels       | int               | Number of input channels to convolutional layer                                                                  |
| out_channels      | int               | Number of output channels to convolutional layer                                                                 |
| kernel_size       | int or int-tuple  | Size of convolutional kernel                                                                                     |
| stride            | int or int-tuple  | Stride of the convolutional kernel                                                                               |
| padding           | None, int or int-tuple | Padding during convolution. If `None` padding is estimated automatically                                    |
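| act_cls           | nn.Module         | The activation function to be used; `None` disables it. Default `nn.ReLU`                                        |
| norm_type         | str               | Prefix of the `torch.nn` 3D normalization layer (e.g. `'Batch'` → `nn.BatchNorm3d`); `None` disables it          |
| blur              | bool              | Add a blur (replication padding + 2×2×2 average pooling with stride 1): after the layers if `transpose`, otherwise before the convolution |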
| transpose         | bool              | Make convolutional layer a transposed convolution                                                                |
| kwargs            | -                 | Further arguments passed to the convolutional layer                                                              |


In [6]:
# export
class SpatialAttention(nn.Module):
    """Apply attention gate to input in U-Net Block. Adapted from arxiv.org/abs/1804.03999

    The skip connection `s` is spatially downsampled by a 2x2x2/stride-2 convolution and resized
    to the resolution of the projected `up_in`. The sum goes through ReLU, a 1x1x1 convolution to
    a single channel and a sigmoid. The resulting mask multiplies `up_in`.

    Args:
        up_channels: channels of `up_in`, the tensor the gate is applied to.
        gated_channels: channels of the skip connection `s`.
    """
    def __init__(self, up_channels, gated_channels):
        super().__init__()
        unit = (1, 1, 1)
        halve = (2, 2, 2)
        # 1x1x1 projection of the upsampling input onto the skip channel count
        self.conv_up = nn.Conv3d(up_channels, gated_channels, kernel_size=unit, stride=unit)
        # bias-free strided conv that halves the skip connection's resolution
        self.conv_s = nn.Conv3d(gated_channels, gated_channels, kernel_size=halve, stride=halve, bias=False)
        # ReLU -> single-channel 1x1x1 conv -> sigmoid yields the attention coefficients
        self.conv_both = nn.Sequential(
            nn.ReLU(),
            nn.Conv3d(gated_channels, 1, kernel_size=unit, stride=unit),
            nn.Sigmoid(),
        )

    def forward(self, up_in, s):
        projected = self.conv_up(up_in)
        target_size = projected.shape[2:]
        skip = F.interpolate(self.conv_s(s), size=target_size, mode='trilinear', align_corners=False)
        coefficients = self.conv_both(projected + skip)
        gate = F.interpolate(coefficients, size=up_in.shape[2:], mode='trilinear', align_corners=False)
        return up_in * gate

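Apply an attention gate to the input of a U-Net block, adapted from arxiv.org/abs/1804.03999. A single-channel sigmoid mask is computed from the input and the (downsampled) skip connection and multiplied onto the input.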

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| up_channels       | int               | Number of channels in tensor to be upsampled. Attention gate will be applied to this tensor                      |
| gated_channels    | int               | Number of channels in gated input (from skip connection)                                                         |



In [7]:
# export
class UnetBlock3D(nn.Module):
    """Create a U-Net Block, optional with spatial attention.

    `up_in` is upsampled by a stride-2 transposed `ConvLayer` that halves its channels. The result is
    concatenated with the normalized skip connection read from `hook.stored`. The concatenation
    passes through an activation and two `ConvLayer`s.

    Args:
        up_channels: channels of the tensor to be upsampled.
        gated_channels: channels of the skip connection.
        hook: object whose `.stored` attribute holds the skip-connection tensor at forward time.
        final_div: if True the output keeps `up_channels//2 + gated_channels` channels,
            otherwise that number is halved.
        blur: blur after the upsampling convolution.
        act_cls: activation class.
        norm_type: prefix of a `torch.nn` 3D norm layer (e.g. 'Batch'); None disables normalization.
        attention_gate: apply `SpatialAttention` to `up_in` before upsampling.
        **kwargs: passed to all `ConvLayer`s.
    """
    def __init__(self, up_channels, gated_channels, hook, final_div=True, blur=False, act_cls=nn.ReLU,
                 norm_type=None, attention_gate=False, **kwargs):
        super(UnetBlock3D, self).__init__()
        self.hook = hook
        self.up = ConvLayer(up_channels, up_channels//2, kernel_size=3, stride=2, blur=blur, act_cls=act_cls,
                            norm_type=norm_type, transpose=True, **kwargs)
        # There is no `nn.NoneNorm3d`: without a norm_type the skip connection is passed through unchanged.
        self.bn = getattr(nn, f'{norm_type}Norm3d')(gated_channels) if norm_type else nn.Identity()
        self.attention_gate = attention_gate
        if attention_gate:
            self.spatial_attention = SpatialAttention(up_channels, gated_channels)

        in_channels = up_channels//2 + gated_channels
        out_channels = in_channels if final_div else in_channels//2

        self.final_conv = nn.Sequential(
            act_cls(),
            ConvLayer(in_channels, out_channels, act_cls=act_cls, norm_type=norm_type, **kwargs),
            ConvLayer(out_channels, out_channels, act_cls=act_cls, norm_type=norm_type, **kwargs)
        )

    def forward(self, up_in):
        s = self.bn(self.hook.stored)
        if self.attention_gate: up_in = self.spatial_attention(up_in, s)
        up_out = self.up(up_in)
        # the transposed conv's output size need not match the skip connection; resize to it
        if s.shape[-3:] != up_out.shape[-3:]:
            up_out = F.interpolate(up_out, s.shape[-3:], mode='nearest')
        cat_x = torch.cat([up_out, s], dim=1)
        return self.final_conv(cat_x)

Create a U-Net block, optionally with spatial attention.

**Args**

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| up_channels       | int               | Number of channels in tensor to be upsampled                                                                     |
| gated_channels    | int               | Number of channels in gated input (from skip connection)                                                         |
| hook              | hook              | Hooked output from encoder layer (implementation of skip connection)                                             |
| final_div         | bool              | If `True`, the block outputs `up_channels//2 + gated_channels` channels; if `False`, half of that                |
| blur              | bool              | Blur the output after upsampling                                                                                 |
| act_cls           | nn.Module         | The activation function to be used. Default `nn.ReLU`                                                            |
| norm_type         | str               | The normalization layer to be used                                                                               |
| attention_gate    | bool              | Use spatial attention in UNet-Block, adapted from arxiv.org/abs/1804.03999                                       |
| kwargs            | -                 | Further arguments passed to the convolutional layers                                                             |


In [8]:
# export
class DeepSupervision(nn.Module):
    """Create segmentation mask from input as described in arxiv.org/abs/1701.03056

    A 1x1x1 convolution maps the feature map to `out_channels` class logits. The logits are
    resized to `img_size` with nearest-neighbour interpolation.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: number of classes.
        img_size: target spatial size of the mask.
        **kwargs: passed to `nn.Conv3d`.
    """
    def __init__(self, in_channels, out_channels, img_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, **kwargs)
        self.img_size = img_size

    def forward(self, x):
        logits = self.conv(x)
        return F.interpolate(logits, size=self.img_size, mode='nearest')

Create segmentation mask from input as described in arxiv.org/abs/1701.03056

**Args**

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| in_channels       | int               | Number of input channels to convolutional layer                                                                  |
| out_channels      | int               | Number of classes                                                                                                |
| img_size          | tuple             | Spatial size (length-3 tuple, typically the input resolution) to which the mask is resized                       |
| kwargs            | -                 | Further arguments passed to the convolutional layers                                                             |



In [9]:
# export
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation layer as described in arxiv.org/pdf/1709.01507.pdf

    Global average pooling squeezes every channel to one value. A 1x1x1 convolution reduces
    the channels to `max(1, int(in_channels * se_ratio))` with `act_cls`. A second 1x1x1
    convolution with a sigmoid expands back to `in_channels`. The resulting weights rescale `x`.

    `norm_type` is accepted but not used: both convolutions are built with `norm_type=None`.
    """
    # ToDo: evaluate final BN - differences for training?
    def __init__(self, in_channels, se_ratio=0.15, act_cls=nn.ReLU, norm_type='None', **kwargs):
        super().__init__()
        assert 0 < se_ratio <= 1, f'Expected `se_ratio` to be between 0 and 1 but got {se_ratio}'
        squeezed = max(1, int(in_channels * se_ratio))

        squeeze = ConvLayer(in_channels=in_channels, out_channels=squeezed, kernel_size=1,
                            act_cls=act_cls, norm_type=None, **kwargs)
        excite = ConvLayer(in_channels=squeezed, out_channels=in_channels, kernel_size=1,
                           act_cls=nn.Sigmoid, norm_type=None, **kwargs)
        self.squeeze_expand = nn.Sequential(nn.AdaptiveAvgPool3d(1), squeeze, excite)

    def forward(self, x):
        channel_weights = self.squeeze_expand(x)
        return channel_weights * x

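Squeeze-and-Excitation layer as described in arxiv.org/pdf/1709.01507.pdf. Global average pooling, a 1×1×1 bottleneck convolution and a sigmoid-activated 1×1×1 expansion produce per-channel weights that rescale the input.

**Args**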

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| in_channels       | int               | Number of input channels to first convolutional layer                                                            |
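| se_ratio          | float             | Squeeze ratio in the interval (0, 1]; the bottleneck has `max(1, int(in_channels * se_ratio))` channels        |
| act_cls           | nn.Module         | Activation after the squeeze convolution (the expand convolution always uses a sigmoid). Default `nn.ReLU`       |
| norm_type         | str               | Currently unused: both convolutions are built without normalization                                              |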
| kwargs            | -                 | Further arguments passed to the convolutional layers                                                             |


In [19]:
# export
class DynamicUnet3D(nn.Module):
    """Create a U-Net from a given architecture, based on fastai's `DynamicUnet`.

    The encoder is evaluated once on a dummy input of `img_size`. Forward hooks are placed on
    the encoder layers selected by `_get_sz_change_idxs` (presumably those after which the
    spatial size changes). These hooks provide the skip connections for the `UnetBlock3D`
    decoder blocks.

    Args:
        encoder: encoder as created by `utils.create_body`; must support `encoder[i]` indexing.
        n_classes: number of output channels (classes).
        img_size: spatial input size used for the dummy pass and for deep-supervision masks.
        act_cls: activation class used in the decoder.
        norm_type: prefix of a `torch.nn` 3D norm layer (e.g. 'Batch').
        blur: blur after the upsampling convolution of the U-Net blocks.
        deep_supervision: build a mask after every decoder layer and fuse them in the final layer.
        se_middle_conv: use a `SqueezeExcitation` middle block instead of two `ConvLayer`s.
        se_ratio: squeeze ratio of the SE middle block.
        attention_gate: use spatial attention in the U-Net blocks (adds a block at encoder index 0
            if it is not already hooked).
        blur_final: if False, the last U-Net block is never blurred even when `blur` is True.
        **kwargs: passed to the convolutional layers of the (non-SE) middle block, the U-Net blocks
            and the final layer.
    """
    # To Do Init layers properly
    def __init__(self, encoder, n_classes, img_size, act_cls=nn.ReLU, norm_type='Batch',
                 blur=False, deep_supervision=False, se_middle_conv=False, se_ratio=0.15, attention_gate=False,
                 blur_final=True, **kwargs):
        super(DynamicUnet3D, self).__init__()

        self.deep_supervision = deep_supervision

        # examine encoder and place hooks after each major block (deepest first)
        sizes = model_sizes(encoder, size=img_size)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        if 0 not in sz_chg_idxs and attention_gate: sz_chg_idxs += [0] # adds an extra U-Net Block with higher resolution attn_gate
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, img_size).detach()

        # construct middle layer
        input_channels = x.size(1)
        if se_middle_conv:
            middle_conv = SqueezeExcitation(input_channels, se_ratio=se_ratio, act_cls=act_cls, norm_type=norm_type)
        else:
            middle_conv = nn.Sequential(
                ConvLayer(input_channels, input_channels*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
                ConvLayer(input_channels*2, input_channels, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()

        x = middle_conv(x)
        self.encoder = encoder
        self.middle_conv = middle_conv

        # create upsample blocks/U-Net Block; `x` is propagated to learn each block's input channels
        layers, ds = [], []
        for i, idx in enumerate(sz_chg_idxs):
            not_final = i != len(sz_chg_idxs) - 1
            up_channels, gated_channels = int(x.shape[1]), int(sizes[idx][1])
            # fastai rule: blur every block when `blur` is set, except the last one if `blur_final` is False
            do_blur = blur and (not_final or blur_final)
            unet_block = UnetBlock3D(up_channels, gated_channels, self.sfs[i], final_div=not_final,
                                     attention_gate=attention_gate, blur=do_blur,
                                     act_cls=act_cls, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)
            if self.deep_supervision: ds.append(DeepSupervision(x.shape[1], n_classes, img_size))

        # add another TransposedConv Layer if the Input is still to small
        input_channels = x.size(1)
        if img_size != sizes[0][-3:]:
            # kernel (1,3,3) / stride (1,2,2): upsamples height and width only
            layers.append(ConvLayer(input_channels, input_channels//2, kernel_size=(1,3,3), act_cls=act_cls, norm_type=norm_type, stride=(1,2,2), transpose=True))
            input_channels = input_channels//2
            if self.deep_supervision: ds.append(DeepSupervision(input_channels, n_classes, img_size))

        self.layers = nn.ModuleList(layers)
        if self.deep_supervision: self.ds = nn.ModuleList(ds)

        # Construct the final layer: 1x1x1 conv on the fused deep-supervision masks or on the last features
        self.final = ConvLayer(n_classes * len(ds) if self.deep_supervision else input_channels, n_classes, kernel_size=1, stride=1, act_cls=None, norm_type=None,  **kwargs)

    def forward(self, x):
        """Return `n_classes` logits resized (nearest) to the spatial size of `x`."""
        sz = x.shape[2:]
        # the hooks registered in __init__ presumably capture the skip features during this call
        x = self.encoder(x)
        x = self.middle_conv(x)
        ds_masks = []
        for i, l in enumerate(self.layers):
            x = l(x)
            if self.deep_supervision:
                ds_masks.append(self.ds[i](x))

        x = self.final(torch.cat(ds_masks, 1) if self.deep_supervision else x)
        x = F.interpolate(x, sz, mode='nearest')

        return x

Create a U-Net from a given architecture, based on fastai DynamicUnet

**Args**

| name              | type              | description                                                                                                      |
|-------------------|-------------------|------------------------------------------------------------------------------------------------------------------|
| encoder           | nn.Sequential     | The encoder architecture as created by `utils.create_body`                                                       |
| n_classes         | int               | Number of classes                                                                                                |
| img_size          | tuple             | Resolution of the input as tuple with len 3                                                                      |
| act_cls           | nn.Module         | The activation function to be used. Default `nn.ReLU`                                                            |
| norm_type         | str               | The normalization layer to be used                                                                               |
| blur              | bool              | Blur the output after upsampling                                                                                 |
| deep_supervision  | bool              | Use deep supervision as described in arxiv.org/abs/1701.03056                                                    |   
| se_middle_conv    | bool              | Add channel wise attention to the middle convolution with a Squeeze-and-Excitation layer arxiv.org/pdf/1709.01507.pdf |
| se_ratio          | float             | Squeeze-Expand ratio, should be a float value between 0 and 1                                                    | 
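| attention_gate    | bool              | Use spatial attention in UNet-Block, adapted from arxiv.org/abs/1804.03999                                       |
| kwargs            | -                 | Further arguments passed to the convolutional layers of the (non-SE) middle block, the U-Net blocks and the final layer |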


In [21]:
# export
def uresnet_18(in_channels, n_classes, img_size, pretrained=True, **kwargs):
    """Build a `DynamicUnet3D` on top of an `r3d_18` encoder body.

    Args:
        in_channels: number of input channels, passed to `create_body` as `n_in`.
        n_classes: number of output classes.
        img_size: spatial input size used to build the decoder.
        pretrained: passed to `create_body` (presumably selects pretrained encoder weights).
        **kwargs: forwarded to `DynamicUnet3D` (e.g. `attention_gate`, `deep_supervision`,
            `se_middle_conv`, `norm_type`).
    """
    arch = create_body(r3d_18, n_in=in_channels, pretrained=pretrained)
    unet = DynamicUnet3D(arch, n_classes=n_classes, img_size=img_size, **kwargs)
    return unet

In [12]:
# hide
# Export all `# export` cells of the project notebooks into the `attention_unet` package.
from nbdev.export import notebook2script
notebook2script()

Converted fastai-hooks.ipynb.
Converted index.ipynb.
Converted unet.ipynb.
Converted utils.ipynb.
