# DeepLabV3+
> 3D implementation of DeepLabV3+

In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.deeplab
# export 
from fastai.basics import *
from fastai.vision.all import create_body, hook_outputs
from fastai.vision.models.unet import _get_sz_change_idxs
from faimed3d.basics import *
from faimed3d.layers import *
from faimed3d.models.unet import AddItems, SequentialEx4D, ResizeToOrig

## DeepLabV3+

Implementation of DeepLabV3+ for 3D. Translates the 2D version from https://github.com/giovanniguidi/deeplabV3-PyTorch to 3D. Adds the functionality to allow multiple encoders, similar to DynamicUnet. However, it probably works best with larger encoders, such as ResNet50.

### ASPP


In [None]:
# export
class ASPPPooling(nn.Sequential):
    """Global-pooling branch of the ASPP head.

    The input is adaptively average-pooled to a single voxel per channel,
    passed through a 1x1x1 `ConvLayer` (`ni` -> `nf` channels) and then
    resized back to the spatial size of the input with trilinear
    interpolation.
    """

    def __init__(self, ni, nf, norm_type=None, act_cls=defaults.activation):
        global_pool = nn.AdaptiveAvgPool3d(1)
        pointwise = ConvLayer(ni=ni, nf=nf, ks=1, ndim=3, bias=False, norm_type=norm_type, act_cls=act_cls)
        super().__init__(global_pool, pointwise)

    def forward(self, x):
        # Remember the incoming spatial size so the pooled features can be
        # broadcast back to it after the sequential pass.
        spatial = x.shape[-3:]
        pooled = super().forward(x)
        return F.interpolate(pooled, size=spatial, mode='trilinear', align_corners=False)


class ASPP(SequentialEx):
    """Atrous Spatial Pyramid Pooling head built from 3D layers.

    Every branch receives the same input: one 1x1x1 convolution, one dilated
    convolution (default `ConvLayer` kernel size) per entry of `dilations`,
    and one `ASPPPooling` branch. The branch outputs are concatenated along
    dim 1 and reduced to `nf` channels by a 1x1x1 projection followed by
    dropout with probability `ps`.
    """

    def __init__(self, ni, dilations, nf, norm_type=None, act_cls=defaults.activation, ps=0.5):
        shared = dict(norm_type=norm_type, act_cls=act_cls)

        # Branches are created in a fixed order: pointwise conv, dilated
        # convs, pooling branch. `padding=rate` matches `dilation=rate`.
        branches = [
            ConvLayer(ni=ni, nf=nf, ks=1, bias=False, ndim=3, **shared),
            *[ConvLayer(ni=ni, nf=nf, ndim=3, dilation=rate, padding=rate, **shared)
              for rate in dilations],
            ASPPPooling(ni=ni, nf=nf, **shared),
        ]
        self.layers = nn.ModuleList(branches)

        # The projection sees the outputs of all branches stacked on dim 1.
        n_concat = nf * len(self.layers)
        self.project = nn.Sequential(
            ConvLayer(ni=n_concat, nf=nf, ks=1, bias=False, ndim=3, **shared),
            nn.Dropout(ps),
        )

    def forward(self, x):
        branch_outs = tuple(branch(x) for branch in self.layers)
        return self.project(torch.cat(branch_outs, dim=1))

In [None]:
# Smoke test with dilations [1, 6, 12, 18]: 2048 input channels are reduced to 256
# and the spatial size of the input is kept.
ASPP(ni=2048, dilations=[1, 6, 12, 18], nf=256, norm_type=NormType.Batch)(torch.randn(10, 2048, 1, 3, 3)).size()

torch.Size([10, 256, 1, 3, 3])

In [None]:
# Same check with the larger dilation set [1, 12, 24, 36]; the output size is unchanged.
ASPP(ni=2048, dilations=[1, 12, 24, 36], nf=256, norm_type=NormType.Batch)(torch.randn(10, 2048, 1, 3, 3)).size()

torch.Size([10, 256, 1, 3, 3])

In [None]:
# export
class DeepLabDecoder(Module):
    """Decoder that fuses the ASPP output with hooked low-level features.

    The tensors stored by `hook` are summed and reduced to `low_lvl_ni // 2`
    channels by a 1x1x1 convolution. The decoder input is resized to the
    spatial size of these features when the sizes differ, concatenated with
    them along dim 1, and passed through two 3x3x3 `ConvLayer`s (each
    followed by dropout) and a final 1x1x1 `nn.Conv3d` with `n_out` output
    channels.
    """

    def __init__(self, ni, low_lvl_ni, hook, n_out, norm_type=None, 
                 act_cls=defaults.activation, ps=0.5):
        self.hook = hook

        reduced = low_lvl_ni // 2
        conv_kwargs = dict(ndim=3, bias=False, norm_type=norm_type, act_cls=act_cls)

        self.low_lvl_conv = ConvLayer(low_lvl_ni, reduced, ks=1, **conv_kwargs)

        head = [
            ConvLayer(ni + reduced, ni, ks=3, stride=1, padding=1, **conv_kwargs),
            nn.Dropout(ps),
            ConvLayer(ni, ni, ks=3, stride=1, padding=1, **conv_kwargs),
            nn.Dropout(ps/5),
            nn.Conv3d(ni, n_out, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*head)

    def forward(self, x):
        # Low-level features come from the hook, not from the argument.
        low_level = self.low_lvl_conv(sum(self.hook.stored))

        target = low_level.shape[-3:]
        if x.shape[-3:] != target:
            x = F.interpolate(x, size=target, mode='nearest')
        return self.last_conv(torch.cat([x, low_level], dim=1))


In [None]:
# export
class DynamicDeepLab(SequentialEx4D):
    """DeepLabV3+ model assembled around a user-supplied 3D encoder.

    The encoder is wrapped in `Arch4D`, a hook is placed on the encoder layer
    at the second size-change index to capture low-level features, and an
    `ASPP` head plus `DeepLabDecoder` are sized from the encoder's output
    channels. `ResizeToOrig` is appended as the last fixed layer.

    Args:
        encoder: encoder body; wrapped as `Arch4D(encoder, n_inp)`.
        n_out: number of output channels of the decoder.
        img_size: size passed to `model_sizes_4d` / `dummy_eval_4d` for the
            dummy forward passes used to infer layer sizes.
        n_inp: number of inputs, forwarded to `Arch4D` and the dummy passes.
        y_range: optional `(low, high)` pair. When given, a `SigmoidRange`
            layer is appended so the output is squashed into that range.
            `None` (the default) leaves the output untouched.
        act_cls: activation class forwarded to `ASPP` and `DeepLabDecoder`.
        norm_type: normalisation type forwarded to `ASPP` and `DeepLabDecoder`.
        **kwargs: accepted but currently unused.
    """

    def __init__(self, encoder, n_out, img_size, n_inp=1, y_range=None, 
                       act_cls=defaults.activation, norm_type=NormType.Batch, **kwargs):

        encoder = Arch4D(encoder, n_inp)
        sizes = model_sizes_4d(encoder, size=img_size, n_inp=n_inp)
        sz_chg_idxs = list(_get_sz_change_idxs(sizes))
        # The hook has to exist before the dummy pass below, so that it holds
        # stored activations when the decoder is dry-run.
        self.sfs = hook_outputs(encoder[sz_chg_idxs[1]], detach=False)
        x = dummy_eval_4d(encoder, img_size, n_inp)
        x = [x_.detach() for x_ in x]
        ni = sizes[-1][1]
        nf = ni//4
        # Wider encoders get the larger dilation rates.
        dilations = [1, 12, 24, 36] if ni > 1024 else [1, 6, 12, 18]
        add_items = AddItems()
        aspp = ASPP(ni=ni, nf=nf, dilations=dilations, norm_type=norm_type, act_cls=act_cls).eval()

        x = aspp(add_items(x))
        decoder = DeepLabDecoder(ni=nf, low_lvl_ni=sizes[sz_chg_idxs[1]][1], hook=self.sfs, n_out=n_out, 
                                 norm_type=norm_type, act_cls=act_cls).eval()
        # Dry run only: checks that the decoder accepts the ASPP output
        # together with the hooked features. The result is discarded.
        decoder(x)

        layers = [encoder, add_items, aspp, decoder, ResizeToOrig()]
        if y_range is not None:
            # Previously `y_range` was accepted but silently ignored.
            from fastai.layers import SigmoidRange
            layers.append(SigmoidRange(*y_range))
        self.layers = nn.ModuleList(layers)

    def __del__(self):
        # Remove the forward hook when the model is garbage-collected; `sfs`
        # may be missing if `__init__` failed before the hook was created.
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
from torchvision.models.video import r3d_18

In [None]:
# Encoder body built from torchvision's r3d_18, without pretrained weights.
body_3d = create_body(r3d_18, pretrained = False)

In [None]:
# Single-input model with 2 output channels: the output keeps the input's
# depth/height/width.
m = DynamicDeepLab(body_3d, 2, (20, 112, 112))
m(torch.randn(1, 3, 20, 112, 112)).shape

torch.Size([1, 2, 20, 112, 112])

In [None]:
# Two-input model (n_inp=2): one tensor is passed per input and a single
# output of the same spatial size is returned.
m = DynamicDeepLab(body_3d, 2, (20, 112, 112),n_inp=2)
m(torch.randn(1, 3, 20, 112, 112), torch.randn(1, 3, 20, 112, 112)).shape

torch.Size([1, 2, 20, 112, 112])

In [None]:
# hide
from nbdev.export import *
notebook2script()