# 3D UNet


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.unet
# export
from fastai.basics import *
from fastai.vision.all import create_body, hook_outputs
from torchvision.models.video import r3d_18
from fastai.vision.models.unet import DynamicUnet, _get_sz_change_idxs

In [None]:
# export
import faimed3d
from faimed3d.layers import *

In [None]:
# Build an encoder body from torchvision's `r3d_18` video ResNet without loading
# pretrained weights; it serves as the example backbone for `DynamicUnet3D` below.
body_3d = create_body(r3d_18, pretrained = False)

## Dynamic Unet 3D

Fastai's `DynamicUnet` allows construction of a UNet using any pretrained CNN as backbone/encoder. A key module is `nn.PixelShuffle`, which allows subpixel convolutions for upscaling in the UNet blocks. However, `nn.PixelShuffle` only supports 2D images, so faimed3d uses `nn.ConvTranspose3d` instead.

In [None]:
# export
class ConvTranspose3D(nn.Sequential):
    """Transposed-convolution block from `ni` filters to `nf` (default `ni`), built on `nn.ConvTranspose3d`.

    The block is a transposed 3D `ConvLayer`, optionally followed by a blur
    (replication padding + average pooling) when `blur=True`.

    - `act_cls`, `norm_type` and `**kwargs` are forwarded to `ConvLayer`.
    - `scale` is kept in the signature but is not used by the current implementation.
      With the default `ConvLayer` settings the output is 2 voxels larger along each
      spatial axis (e.g. 3x13x13 -> 5x15x15), not twice as large.
    """
    def __init__(self, ni, nf=None, scale=2, blur=False, act_cls=None, norm_type=None, **kwargs):
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf, ndim=3, act_cls=act_cls, norm_type=norm_type, transpose=True, **kwargs)]
        # Blur: pad one voxel on one side of every spatial axis, then average over 2x2x2
        # windows with stride 1, so the spatial size stays the same.
        if blur: layers += [nn.ReplicationPad3d((1,0,1,0,1,0)), nn.AvgPool3d(2, stride=1)]
        # Single base-class initialisation, once the layer list is complete.
        super().__init__(*layers)

Fastai's `PixelShuffle_ICNR` first performs a convolution to increase the layer size, then applies `PixelShuffle` to resize the image. A special initialization technique (ICNR) is applied to the weights of that convolution, which can reduce checkerboard artifacts (see https://arxiv.org/pdf/1707.02937.pdf). It is probably not needed for `nn.ConvTranspose3d`.

In [None]:
# Shape check without blur: channels go 256 -> 128 and every spatial axis grows by 2.
ConvTranspose3D(256, 128)(torch.randn((1, 256, 3, 13, 13))).size()

torch.Size([1, 128, 5, 15, 15])

In [None]:
# Same shape check with blur enabled: the output size is unchanged by the blur layers.
ConvTranspose3D(256, 128, blur = True)(torch.randn((1, 256, 3, 13, 13))).size()

torch.Size([1, 128, 5, 15, 15])

To work with 3D data, the `UnetBlock` from fastai is adapted, replacing `PixelShuffle_ICNR` with the `ConvTranspose3D` created above and adapting all conv layers and norm layers to the third dimension. As small differences in size may appear, the `forward` function contains an interpolation step, which is also adapted to work with 5D input instead of 4D.

In [None]:
# export
class UnetBlock3D(Module):
    """A quasi-UNet block for 3D data that upsamples with `ConvTranspose3D`.

    `up_in_c` is the channel count of the incoming (lower-resolution) tensor and `x_in_c`
    the channel count of the tensor stored on `hook`. The upsampled tensor (with
    `up_in_c//2` channels) is concatenated with the batch-normed stored tensor and passed
    through two `ConvLayer`s; their output has `up_in_c//2 + x_in_c` channels, halved when
    `final_div` is False. `self_attention=True` adds a `SelfAttention` to the second conv.
    """
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        half_up_c = up_in_c//2
        self.hook = hook
        self.up = ConvTranspose3D(up_in_c, half_up_c, blur=blur, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.bn = BatchNorm(x_in_c, ndim=3)
        # Channels after concatenating the upsampled tensor with the hooked one.
        n_cat = half_up_c + x_in_c
        n_feat = n_cat if final_div else n_cat//2
        self.conv1 = ConvLayer(n_cat, n_feat, ndim=3, act_cls=act_cls, norm_type=norm_type, **kwargs)
        attention = SelfAttention(n_feat) if self_attention else None
        self.conv2 = ConvLayer(n_feat, n_feat, ndim=3, act_cls=act_cls, norm_type=norm_type,
                               xtra=attention, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        "Upsample `up_in`, match it to the hooked tensor's last three dims, concatenate and convolve."
        skip = self.hook.stored
        upsampled = self.up(up_in)
        target_size = skip.shape[-3:]
        # The transposed conv may not land exactly on the stored tensor's size.
        if upsampled.shape[-3:] != target_size:
            upsampled = F.interpolate(upsampled, target_size, mode='nearest')
        merged = torch.cat([upsampled, self.bn(skip)], dim=1)
        return self.conv2(self.conv1(self.relu(merged)))
        

The output size of the last UNet block can be slightly different from the original input size, so one of the last steps in `DynamicUnet` is `ResizeToOrig`, which is also adapted to work with 5D instead of 4D input images.

In [None]:
# export
class ResizeToOrig(Module):
    "Resize `x` to the last three dimensions of `x.orig` (interpolating with `mode`) when they differ."
    def __init__(self, mode='nearest'):
        self.mode = mode

    def forward(self, x):
        target_size = x.orig.shape[-3:]
        # Already the right size: hand the input back untouched.
        if x.shape[-3:] == target_size: return x
        return F.interpolate(x, target_size, mode=self.mode)

The `DynamicUnet3D` is very similar to fastai's `DynamicUnet`. The key differences are that each `ConvLayer` or `BatchNorm` gets an extra `ndim` argument and that `UnetBlock` is replaced by `UnetBlock3D`.

In [None]:
# export
class DynamicUnet3D(SequentialEx):
    """Create a 3D U-Net from a given encoder architecture.

    `encoder` is evaluated once on a dummy input of spatial size `img_size`. Each layer at
    which the activation size changes gets a hook, and one `UnetBlock3D` per hook is stacked
    on top of the encoder, in reverse encoder order.

    - `n_out`: number of output channels of the final `ConvLayer` (kernel size 1).
    - `img_size`: the three trailing (spatial) sizes of the input, e.g. `(10, 50, 50)`;
      any sequence (tuple, list, `torch.Size`) is accepted.
    - `blur` / `blur_final`: enable blur in the `UnetBlock3D`s; the last block is only
      blurred if `blur_final` is also set.
    - `self_attention`: add self-attention to the third-from-last `UnetBlock3D`.
    - `y_range`: if given, `SigmoidRange(*y_range)` is appended as the last layer.
    - `last_cross`: append `MergeLayer(dense=True)` and a `ResBlock` before the final conv.
    - `bottle`: ask that `ResBlock` for `ni//2` instead of `ni` output filters.
    - `act_cls`, `norm_type`, `**kwargs`: forwarded to the created layers; `init` is the
      initialiser passed to `apply_init` and to every `UnetBlock3D`.
    """
    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # A list never compares equal to a tuple-like shape, which would make the size
        # check further down always true; work with a plain tuple from here on.
        img_size = tuple(img_size)
        ndim = len(img_size)
        sizes = model_sizes(encoder, size=img_size)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        # Dummy activation that is pushed through every new layer to track channel counts.
        x = dummy_eval(encoder, img_size).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(
            ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, ndim=ndim, **kwargs),
            ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, ndim=ndim, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni, ndim=ndim), nn.ReLU(), middle_conv]

        n_blocks = len(sz_chg_idxs)
        for i,idx in enumerate(sz_chg_idxs):
            not_final = i != n_blocks-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == n_blocks-3)
            unet_block = UnetBlock3D(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                     act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            # Run the block on the dummy activation so the next block sees its output channels.
            x = unet_block(x)

        ni = x.shape[1]
        # `sizes[0]` is presumably the output shape of the first encoder layer; if it already
        # differs from the input size, one more transposed conv is added before resizing.
        if img_size != tuple(sizes[0][-3:]): layers.append(ConvTranspose3D(ni))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            # NOTE(review): with `bottle=True` the ResBlock is asked for `ni//2` filters while the
            # final ConvLayer below still takes `ni` inputs - confirm this combination works.
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, ndim=ndim, **kwargs))
        layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, ndim=ndim, **kwargs)]
        # layers[3] is `middle_conv`; layers[-2] is the layer right before the final conv.
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        "Remove the encoder hooks, if they were created, when the model is deleted."
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
# Example model: `r3d_18` body as encoder, 2 output channels, built for 10x50x50 inputs.
m = DynamicUnet3D(body_3d, 2, (10,50,50))

In [None]:
# Smoke test: forward a random batch of two 3-channel 10x50x50 volumes through the model.
# The output shown below is the repr of `m`.
m(torch.rand(2, 3, 10, 50, 50)).size()
m

DynamicUnet3D(
  (layers): ModuleList(
    (0): Sequential(
      (0): BasicStem(
        (0): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): BasicBlock(
          (conv1): Sequential(
            (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv2): Sequential(
            (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (relu): ReLU(inplace=True)
        )
        (1): BasicBlock(
          (conv1): Sequential(

In [None]:
# hide
from nbdev.export import *
notebook2script()

Converted 01_basics.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_transforms.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_layers.ipynb.
Converted 06_learner.ipynb.
Converted 06a_models.alexnet.ipynb.
Converted 06b_models.resnet.ipynb.
Converted 06d_models.unet.ipynb.
Converted 06f_models.losses.ipynb.
Converted 07_callback.ipynb.
Converted index.ipynb.
