# 3D UNet


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.unet
# export 
from fastai.basics import *
from fastai.vision.all import *
from faimed3d.models.resnet import *
from fastai.vision.models.unet import DynamicUnet, _get_sz_change_idxs

In [None]:
# Fetch the tiny CamVid sample and prepare `build_dls`, a partial of
# `SegmentationDataLoaders.from_label_func` with everything but the remaining
# options pre-filled: batch size 8, all files under `images/`, class codes read
# from `codes.txt`, and masks looked up as `labels/<stem>_P<suffix>`.
path = untar_data(URLs.CAMVID_TINY)
build_dls = partial(SegmentationDataLoaders.from_label_func, path=path, 
                    bs = 8, fnames = get_image_files(path/'images'),
                    label_func = lambda o: path/'labels'/f'{o.stem}_P{o.suffix}',
                    codes = np.loadtxt(path/'codes.txt', dtype=str))

In [None]:
# Encoder bodies for the experiments below: one built from the 3D ResNet-18
# and one from the regular 2D ResNet-18, each passed to `DynamicUnet` as `encoder`.
body_3d = create_body(resnet18_3d)
body_2d = create_body(resnet18)

In [None]:
# Baseline check with the 2D body: a (5, 3, 100, 100) batch should come back
# with 2 output channels at the same spatial size.
DynamicUnet(body_2d, 2, (100,100))(torch.randn(5, 3, 100, 100)).size()

torch.Size([5, 2, 100, 100])

In [None]:
class DynamicUnet(SequentialEx):
    """Create a U-Net from a given architecture.

    The decoder is built dynamically: a dummy input of size `img_size` is pushed
    through `encoder`, the layers at which the activation size changes are
    hooked, and one `UnetBlock` is added per hooked layer (deepest first).

    Parameters
    - encoder: indexable module used as the down-sampling path (`encoder[i]` is hooked).
    - n_classes: number of output channels of the final 1x1 `ConvLayer`.
    - img_size: spatial size of the input, e.g. `(H, W)` or `(D, H, W)`. Its length
      is used as the number of spatial dimensions (`ndim`) for the layers created here.
    - blur / blur_final: enable blur in the `UnetBlock`s; `blur_final` controls
      whether the last block is blurred too.
    - self_attention: enable self attention in the third-to-last `UnetBlock`.
    - y_range: if given, a `SigmoidRange(*y_range)` is appended as last layer.
    - last_cross: if true, merge the network input back in (`MergeLayer(dense=True)`)
      and add a `ResBlock` before the final convolution.
    - bottle: if true, the last-cross `ResBlock` is created with `ni//2` output features.
      NOTE(review): the final `ConvLayer` is still created with `ni` inputs, which
      looks inconsistent with `bottle=True` -- confirm before relying on it.
    - act_cls, norm_type, **kwargs: forwarded to the layers created here.
    - init: initialiser applied to the middle convolution and to the
      second-to-last layer via `apply_init`.

    NOTE(review): `UnetBlock`, `PixelShuffle_ICNR` and `ResizeToOrig` are used as
    imported and are not told about `ndim`; the 3D call in this notebook fails
    with a 4-dimensional-weight error, which presumably originates in one of
    them -- they likely need 3D counterparts before 3D input works end to end.
    """
    def __init__(self, encoder, n_classes, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        imsize = img_size
        # Number of spatial dimensions; used consistently for every layer built
        # here instead of mixing `len(imsize)` with hard-coded 2D defaults.
        ndim = len(imsize)
        sizes = model_sizes(encoder, size=imsize)

        # Indices of encoder layers where the activation size changes, deepest first.
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        # Keep the hooks on `self`: the UnetBlocks read the stored activations
        # at forward time, and `__del__` removes them.
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, ndim=ndim, **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, ndim=ndim, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni, ndim=ndim), nn.ReLU(), middle_conv]

        for i,idx in enumerate(sz_chg_idxs):
            not_final = i!=len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i==len(sz_chg_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            # Track the running activation so the next block knows its input channels.
            x = unet_block(x)

        ni = x.shape[1]
        # Compare against all `ndim` trailing (spatial) dims of the first hooked
        # activation; slicing only the last two never matches a 3-element `imsize`.
        if imsize != sizes[0][-ndim:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type,
                                   ndim=ndim, **kwargs))
        layers += [ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, ndim=ndim, **kwargs)]
        # Must run before the optional SigmoidRange is appended so that
        # `layers[-2]` still refers to the layer just before the final conv.
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        "Remove the encoder hooks, if they were created."
        # `sfs` may be missing when `__init__` failed before the hooks were registered.
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
def model_sizes(m, size=(64,64)):
    "Hook the outputs of `m`, run one dummy input of `size` through it and return the stored activation shapes."
    with hook_outputs(m) as hooks:
        # The forward pass is only needed for its side effect of filling the hooks.
        dummy_eval(m, size=size)
        shapes = [hook.stored.shape for hook in hooks]
    return shapes

In [None]:
def dummy_eval(m, size=(64,64)):
    "Put `m` in eval mode and return its output, computed without gradients, for one random input of spatial `size`."
    n_channels = in_channels(m)
    # Build the input from one of the model's own parameters
    # (presumably so it shares that parameter's dtype/device).
    template = one_param(m)
    dummy = template.new(1, n_channels, *size)
    dummy = dummy.requires_grad_(False).uniform_(-1.,1.)
    with torch.no_grad():
        model = m.eval()
        return model(dummy)

In [None]:
def in_channels(m):
    "Return `weight.shape[1]` of the first layer in `m` whose weight has 4 or 5 dimensions; raise if there is none."
    for layer in flatten_model(m):
        weight = getattr(layer, 'weight', None)
        if weight is None: continue
        # Only 4-D or 5-D weights count; layers with other weight ranks are skipped.
        if weight.ndim == 4 or weight.ndim == 5: return weight.shape[1]
    raise Exception('No weight layer')

In [None]:
# Same construction with the 3D body and a (depth, height, width) size;
# in this notebook run it raises the RuntimeError shown in the cell output.
DynamicUnet(body_3d, 2, (10,100,100))

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [1024, 512, 1, 1], but got 5-dimensional input of size [1, 512, 1, 7, 7] instead

In [None]:
UnetBlock??

[0;31mInit signature:[0m
[0mUnetBlock[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mup_in_c[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mx_in_c[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhook[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfinal_div[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mblur[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mact_cls[0m[0;34m=[0m[0;34m<[0m[0;32mclass[0m [0;34m'torch.nn.modules.activation.ReLU'[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mself_attention[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minit[0m[0;34m=[0m[0;34m<[0m[0mfunction[0m [0mkaiming_normal_[0m [0mat[0m [0;36m0x7f1773681320[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnorm_type[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mks[0m[0;34m=[0m[0;36m3[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstride[0m[0;34m=[0m[0;36m1[