# Hybrid DeepLabV3

In [2]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [5]:
# default_exp models.deeplab
# export 
from fastai.basics import *
from fastai.vision.all import create_body, hook_outputs
from fastai.vision.models.unet import _get_sz_change_idxs
from faimed3d.basics import *
from faimed3d.layers import *
from faimed3d.models.unet import ResizeToOrig
from faimed3d.models.deeplab import ASPP, ASPPPooling


In [27]:
# export
class HybridDeepLabDecoder(Module):
    "Decoder Block for DynamicDeeplab"
    def __init__(self, ni, low_lvl_ni, hook, n_out, norm_type=None, 
                 act_cls=defaults.activation, ps=0.5):
        self.hook = hook
        
        # segmentation
        self.low_lvl_conv = ConvLayer(low_lvl_ni, low_lvl_ni//2, ks=1, ndim=3, bias=False, 
                                      norm_type=norm_type, act_cls=act_cls)

        self.last_conv = nn.Sequential(
                ConvLayer(ni+low_lvl_ni//2, ni, ks=3, ndim=3, stride=1, padding=1, bias=False, 
                          norm_type=norm_type, act_cls=act_cls), 
                nn.Dropout(ps),
                ConvLayer(ni, ni, ks=3, ndim=3, stride=1, padding=1, bias=False, 
                          norm_type=norm_type, act_cls=act_cls), 
                nn.Dropout(ps/5),
                nn.Conv3d(ni, n_out, kernel_size=1, stride=1))
              
        # classification
        self.clf = nn.Sequential(
                    LinBnDrop(n_in=ni, n_out=ni//2, p = ps),
                    act_cls(), 
                    LinBnDrop(n_in=ni//2, n_out=2, p = ps//2)
        )


    def forward_seg(self, x):
        x_clf = self.clf(x)
        
        s = self.low_lvl_conv(sum(self.hook.stored))
        ssh = s.shape[-3:]
        if ssh != x.shape[-3:]:
            x = F.interpolate(x, size=ssh, mode='nearest')
        x = torch.cat((x, s), dim=1)
        x_seg = self.last_conv(x)
        return x_seg, x_clf
    


In [30]:
class HybridResizeToOrig(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, mode='nearest'): self.mode = mode
    def forward(self, x_seg, x_clf):
        if x_seg.orig.shape[-3:] != x_seg.shape[-3:]:
            x_seg = F.interpolate(x_seg, x_seg.orig.shape[-3:], mode=self.mode)
        return x_seg, x_clf

In [43]:
# export
class DynamicDeepLab(SequentialEx):
    "Build DeepLab with different encoders"
    def __init__(self, encoder, n_out, img_size, y_range=None, 
                       act_cls=defaults.activation, norm_type=NormType.Batch, **kwargs):
        
        sizes = model_sizes(encoder, size=img_size)
        sz_chg_idxs = list(_get_sz_change_idxs(sizes))
        self.sfs = hook_outputs(encoder[sz_chg_idxs[1]], detach=False)
        x = dummy_eval(encoder, img_size).detach()
        ni = sizes[-1][1]
        nf = ni//4
        dilations=[1, 12, 24, 36] if ni > 512 else [1, 6, 12, 18]
        aspp = ASPP(ni=ni, nf=nf, dilations=dilations, norm_type=norm_type, act_cls=act_cls).eval()
        x = aspp(x)
        decoder = HybridDeepLabDecoder(ni=nf, low_lvl_ni=sizes[sz_chg_idxs[1]][1], hook=self.sfs, n_out=n_out, 
                                       norm_type=norm_type, act_cls=act_cls).eval()
        x = decoder(x)
        self.layers = nn.ModuleList([encoder, aspp, decoder, HybridResizeToOrig])

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()        

In [44]:
from torchvision.models.video import r3d_18

In [45]:
body_3d = create_body(r3d_18, pretrained = False)

In [46]:
m = DynamicDeepLab(body_3d, 2, (20, 64, 64))
m(torch.randn(1, 3, 20, 64, 64)).shape

NotImplementedError: 

In [None]:
# hide
from nbdev.export import *
notebook2script()

Converted 01_basics.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_transforms.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_layers.ipynb.
Converted 06_learner.ipynb.
Converted 06a_models.alexnet.ipynb.
Converted 06b_models.resnet.ipynb.
Converted 06c_model.efficientnet.ipynb.
Converted 06d_models.unet.ipynb.
Converted 06e_models.deeplabv3.ipynb.
Converted 06f_models.losses.ipynb.
Converted 07_callback.ipynb.
Converted index.ipynb.
