# 3D UNet


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.unet
# export
from fastai.basics import *
from faimed3d.basics import *
from fastai.vision.all import create_body, hook_outputs
from torchvision.models.video import r3d_18
from fastai.vision.models.unet import DynamicUnet, _get_sz_change_idxs

In [None]:
# export
import faimed3d
from faimed3d.layers import *

In [None]:
# Encoder body built from torchvision's `r3d_18` via fastai's `create_body`, without pretrained weights.
# Used as the encoder in the UNet examples of this notebook.
body_3d = create_body(r3d_18, pretrained = False)

## Dynamic Unet 3D

Fastai's `DynamicUnet` allows construction of a UNet using any pretrained CNN as backbone/encoder. A key module is `nn.PixelShuffle` which allows subpixel convolutions for upscaling in the UNet Blocks. However, `nn.PixelShuffle` is only for 2D images, so in faimed3d `nn.ConvTranspose3d` is used instead. 

In [None]:
# export
class ConvTranspose3D(nn.Sequential):
    """Transposed 3D convolution from `ni` filters to `nf` (default `ni`), using `nn.ConvTranspose3d`.

    The layer is a single `ConvLayer(..., ndim=3, transpose=True)`. Note that `scale` is currently
    not used, so this does NOT upsample by `scale`: with fastai's default `ConvLayer` settings each
    spatial dimension grows by 2 voxels (e.g. 3x13x13 -> 5x15x15). `UnetBlock3D` interpolates the
    result to the size of its skip connection afterwards.

    Args:
        ni: number of input channels.
        nf: number of output channels; defaults to `ni`.
        scale: accepted for parity with fastai's `PixelShuffle_ICNR`; currently ignored.
        blur: if True, append a one-voxel replication pad plus a 2x2x2 average pool with stride 1,
            which keeps the spatial size unchanged.
        act_cls, norm_type, **kwargs: forwarded to `ConvLayer`.
    """
    def __init__(self, ni, nf=None, scale=2, blur=False, act_cls=None, norm_type=None, **kwargs):
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf, ndim=3, act_cls=act_cls, norm_type=norm_type, transpose=True, **kwargs)]
        if blur:
            # Pad one voxel at the front of W, H and D, then average-pool 2x2x2 with stride 1:
            # smooths the output while leaving its spatial size unchanged.
            layers += [nn.ReplicationPad3d((1, 0, 1, 0, 1, 0)), nn.AvgPool3d(2, stride=1)]
        # Single nn.Sequential initialization; no submodule is assigned to `self` before this point.
        super().__init__(*layers)

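Fastai's `PixelShuffle_ICNR` first performs a convolution to increase the number of channels, then applies `PixelShuffle` to rearrange those channels into a larger image. A special initialization technique (ICNR) is applied to that convolution, which can reduce checkerboard artifacts (see https://arxiv.org/pdf/1707.02937.pdf). It is probably not needed for `nn.ConvTranspose3d`. Note that, with the default `ConvLayer` settings, `ConvTranspose3D` enlarges each spatial dimension by 2 voxels rather than doubling it (see the outputs below); the remaining size difference is handled by interpolation in `UnetBlock3D`.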

In [None]:
# 256 -> 128 channels; the spatial size grows from 3x13x13 to 5x15x15 (see output).
ConvTranspose3D(256, 128)(torch.randn((1, 256, 3, 13, 13))).size()

torch.Size([1, 128, 5, 15, 15])

In [None]:
# With blur=True the output size is the same as without blur (see output).
ConvTranspose3D(256, 128, blur = True)(torch.randn((1, 256, 3, 13, 13))).size()

torch.Size([1, 128, 5, 15, 15])

To work with 3D data, the `UnetBlock` from fastai is adapted: `PixelShuffle_ICNR` is replaced with the `ConvTranspose3D` created above, and all conv layers and norm layers are adapted to the third dimension. As small differences in size may appear, the `forward` method contains an interpolation step, which is also adapted to work with 5D input instead of 4D.  
`UnetBlock3D` receives the lower-level features through a hook. In contrast to `fastai`, the hook stores a list of tensors with `len(hook.stored) == n_inp`, where `n_inp` is the number of 3D sequences in a 4D input. **Important:** 4D here does not refer to the dimensionality of the tensor (where 4D would be B x C x H x W) but to the dimensionality of the input **before** it was concatenated into a tensor. 3D means we have one 3D volume with dimensionality (C) x D x H x W, and 4D means we have multiple 3D volumes. In medical imaging this is common, as we often want to use information from multiple imaging sequences. 
The information in the different sequences can be assumed to be redundant to some extent. So in `UnetBlock3D`, the feature maps from the hook are first concatenated and then pooled using a 1x1x1 convolutional layer (only when there is more than one feature map). Apart from this, the class is built similarly to the fastai `UnetBlock`.

In [None]:
# export
def noop(x=None, *args, **kwargs):
    """Identity function: return `x` unchanged and ignore any extra arguments.

    The signature mirrors fastcore's `noop`. This definition shadows the one brought in by
    `from fastai.basics import *`, so callers that pass no argument or extra arguments keep working.
    """
    return x

In [None]:
# export
class UnetBlock3D(Module):
    """A quasi-UNet decoder block for 3D data, using `ConvTranspose3D` for the up-path.

    `hook.stored` holds the skip-connection feature map(s) captured from the encoder. When it
    holds more than one tensor (one per input sequence), they are concatenated along the channel
    axis and reduced back to `x_in_c` channels with a 1x1x1 `ConvLayer`.

    Args:
        up_in_c: channels of the tensor coming from the layer below (`up_in` in `forward`).
        x_in_c: channels of a single stored skip feature map.
        hook: hook whose `.stored` is a sequence of tensors. It must already be populated
            (e.g. by a dummy forward pass) because its length decides whether pooling is needed.
        final_div: if True, output `ni = up_in_c//2 + x_in_c` channels; otherwise `ni//2`.
        blur: forwarded to `ConvTranspose3D`.
        act_cls, norm_type, **kwargs: forwarded to the convolution layers.
        self_attention: add `SelfAttention` after the second convolution.
        init: init function applied to `conv1` and `conv2`.
    """
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        self.hook = hook
        # Several stored maps (one per input sequence): concatenate, then pool back to x_in_c channels.
        if len(hook.stored) > 1:
            self.pool_fm = ConvLayer(x_in_c*len(hook.stored), x_in_c, ks = 1, ndim=3, act_cls=act_cls, norm_type=norm_type, **kwargs)
        else: self.pool_fm = noop
        self.up = ConvTranspose3D(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.bn = BatchNorm(x_in_c, ndim=3)
        ni = up_in_c//2 + x_in_c
        nf = ni if final_div else ni//2
        self.conv1 = ConvLayer(ni, nf, ndim=3, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, ndim=3, act_cls=act_cls, norm_type=norm_type,
                               xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        "Upsample `up_in`, merge it with the stored skip feature map(s) and apply two convolutions."
        s = self.pool_fm(torch.cat(self.hook.stored, 1))
        up_out = self.up(up_in)
        target = s.shape[-3:]
        # The transposed conv does not necessarily reproduce the skip's D x H x W, so resize to match.
        if up_out.shape[-3:] != target:
            up_out = F.interpolate(up_out, target, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
        

The output size of the last UNet block can be slightly different from the original input size, so one of the last steps in `DynamicUnet` is `ResizeToOrig`, which is also adapted to work with 5D instead of 4D input images. 

In [None]:
# export
class ResizeToOrig(Module):
    """Resize `x` to the spatial size (last three dims) of `x.orig` using `F.interpolate`.

    `x.orig` must be attached by the surrounding container (as `SequentialEx4D.forward` does).
    `x` is returned unchanged when the sizes already match.

    Args:
        mode: interpolation mode passed to `F.interpolate`.
    """
    def __init__(self, mode='nearest'): self.mode = mode
    def forward(self, x):
        target = x.orig.shape[-3:]
        if x.shape[-3:] != target:
            x = F.interpolate(x, target, mode=self.mode)
        return x

`SequentialEx` does not allow passing more than one item to `forward`, so it is subclassed to allow tuples. Also, the input needs to pass through the first two layers outside of the loop, because the inputs are still a tuple/list at this stage. The second layer concatenates its input into a single tensor, which can then be passed through the loop.

In [None]:
# export
class SequentialEx4D(SequentialEx):
    """Like `SequentialEx`, but handles `orig` data differently and accepts several inputs (tuple/list).

    `layers[0]` (the encoder) receives all inputs as a tuple. `layers[1]` receives the encoder
    output as a tuple and must return a single tensor. Every later layer is called with that
    tensor carrying `.orig = inputs[0]`, so layers such as `ResizeToOrig` can access the first input.
    """
    def forward(self, *inputs):
        # Can't assign an attribute to a tuple/list, so the encoder and concat run outside the loop.
        res = self.layers[0](tuple(inputs))  # encoder
        res = self.layers[1](tuple(res))     # concat; after this `res` is no longer a list/tuple
        for layer in self.layers[2:]:
            # Only the first input is exposed as `orig` to the later layers.
            res.orig = inputs[0]
            nres = layer(res)
            # Remove the references to avoid hanging refs and therefore memory leaks. `nres` may be
            # `res` itself (a layer returning its input). Clear `nres.orig`, not `nres[0].orig`:
            # indexing would select the first batch item (a throwaway view) and fails on an empty batch.
            res.orig = None
            if isinstance(nres, torch.Tensor): nres.orig = None
            res = nres
        return res

`faimed3d` needs to support 4D data, that is multiple 3D inputs. The reason a radiologist uses multiple sequences is that some information is only present in a certain sequence. For example, to make the diagnosis of stroke, one needs a strong hyperintense signal in the DWI and also a corresponding hypointense signal on the ADC map. It is (nearly) impossible to make a diagnosis only from one sequence. When we build a model, we also want it to have access to all relevant information.  
The way fastai handles multiple inputs is to store them in a tuple. So if we have two 3D volumes as input and one mask as target, the batch will be as follows:  
(TensorDicom3D of size B x 3 x D x H x W, TensorDicom3D of size B x 3 x D x H x W, TensorMask3D of size B x 3 x D x H x W).
`faimed3d` assumes that all sequences have roughly the same orientation (e.g. all axial) and cover the same region. So, information in the sequences can be assumed to be redundant to some extent. One likely does not need an encoder for each sequence and can instead re-use the weights of one encoder for all images. This approach saves a lot of memory, but makes training longer. However, it might be beneficial to still have some separate weights for each sequence. For this reason, `faimed3d` splits a given encoder into its stem and main body and duplicates the stem according to the number of inputs, using the `MultiStem` class. 

`DynamicUnet3D` is the main UNet class for `faimed3d` and is very similar to the `fastai` `DynamicUnet`. 
The key differences are the adaptations to 3D and 4D inputs. In the `fastai` `DynamicUnet`, the feature maps are stored via hooks each time the size of the feature maps changes in the encoder. In `faimed3d` we can have multiple inputs but only one encoder, so the hooks return a list of tensors. This is addressed by adapting `model_sizes` and `dummy_eval`, adding an extra concat layer and adapting the `UnetBlock`.

In [None]:
# export
class DynamicUnet3D(SequentialEx4D):
    """Create a 3D U-Net from a given encoder, for one (`n_inp=1`) or several 3D input volumes.

    Args:
        encoder: encoder body (e.g. built with `create_body`); wrapped in `Arch4D` for `n_inp` inputs.
        n_out: number of output channels of the final 1x1x1 `ConvLayer`.
        img_size: spatial size of one input volume used for the dummy forward pass, e.g. `(10, 50, 50)`.
        n_inp: number of input volumes passed to `forward`.
        blur: blur in the up-path of the UNet blocks; `blur_final` controls whether the last block blurs too.
        self_attention: add `SelfAttention` to the third-to-last UNet block.
        y_range: if given, append `SigmoidRange(*y_range)` to the output.
        last_cross: merge the first input (`orig`) into the decoder output via `MergeLayer(dense=True)`
            and add a `ResBlock`.
        bottle: halve the number of channels in that final `ResBlock`.
        act_cls, init, norm_type, **kwargs: forwarded to the created layers.
    """
    def __init__(self, encoder, n_out, img_size, n_inp=1, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):

        encoder = Arch4D(encoder, n_inp)
        sizes = model_sizes_4d(encoder, size=img_size, n_inp=n_inp) # return sizes * n_inp
        # Encoder layer indices after which the feature-map size changes, deepest first.
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))

        # Hook those layers; the UNet blocks read their stored outputs as skip connections.
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        # Dummy pass: populates the hooks and yields the encoder outputs used to size the decoder.
        x = dummy_eval_4d(encoder, img_size, n_inp)
        x = [x_.detach() for x_ in x]
        ni = sizes[-1][1]

        # Reduce the concatenated ni*n_inp channels to ni*2, then back to ni.
        middle_conv = nn.Sequential(ConvLayer(ni*n_inp, ni*2, act_cls=act_cls, norm_type=norm_type, ndim = len(img_size), **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, ndim = len(img_size), **kwargs)).eval()

        # Turns the list of per-input encoder outputs into a single tensor.
        concat = Concat(ni*n_inp, ndim = len(img_size))

        x = middle_conv(concat(x))

        layers = [encoder, concat, middle_conv]

        for i,idx in enumerate(sz_chg_idxs):
            not_final = i!=len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i==len(sz_chg_idxs)-3)

            unet_block = UnetBlock3D(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                     act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        # Compare as tuples: a list `img_size` would otherwise never equal the recorded size.
        # NOTE(review): fastai passes act_cls/norm_type to its final upsampler; this one uses the defaults.
        if tuple(img_size) != tuple(sizes[0][-3:]): layers.append(ConvTranspose3D(ni))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, ndim = 3, **kwargs))
        layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, ndim = 3, **kwargs)]
        # Initialize the middle convs and the layer before the final conv. fastai's DynamicUnet uses
        # `layers[3]` here because its middle conv sits at index 3. In this layout index 3 is the
        # first UnetBlock3D, which already initializes its own convs.
        apply_init(nn.Sequential(middle_conv, layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # Detach the encoder hooks registered in __init__.
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
# Single-input UNet (n_inp defaults to 1) with 2 output channels for 10x50x50 volumes.
unet3d = DynamicUnet3D(body_3d, 2, (10,50,50))

In [None]:
# The output keeps the input's spatial size (10x50x50) with n_out=2 channels (see output).
unet3d(torch.rand(2, 3, 10, 50, 50)).size()

torch.Size([2, 2, 10, 50, 50])

In [None]:
# Same settings, but n_inp=4: the model expects four input volumes.
# NOTE(review): `body_3d` is reused from `unet3d`; unless `Arch4D` copies it, both models share
# encoder modules (and their hooks). Build a fresh body if they must be independent.
unet4d = DynamicUnet3D(body_3d, 2, (10,50,50), 4)

In [None]:
# Four inputs of identical shape produce a single 2-channel output at the input resolution (see output).
unet4d(torch.rand(2, 3, 10, 50, 50), torch.rand(2, 3, 10, 50, 50), torch.rand(2, 3, 10, 50, 50), torch.rand(2, 3, 10, 50, 50)).size()

torch.Size([2, 2, 10, 50, 50])

In [None]:
# hide
from nbdev.export import *
# Convert the notebooks to library modules (see the "Converted ..." output).
notebook2script()

Converted 01_basics.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_transforms.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_layers.ipynb.
Converted 06_learner.ipynb.
Converted 06a_models.alexnet.ipynb.
Converted 06b_models.resnet.ipynb.
Converted 06d_models.unet.ipynb.
Converted 06f_models.losses.ipynb.
Converted 07_callback.ipynb.
Converted index.ipynb.
