# 3D DeepLabV3+


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.deeplabv3
# export 
from fastai.basics import *
from fastai.layers import *
from warnings import warn
from faimed3d.models.modules import Sequential_
from faimed3d.models.resnet import *

## DeepLabV3+

A 3D implementation of DeepLabV3+.
Credit to https://github.com/giovanniguidi/deeplabV3-PyTorch for providing the code of a 2D implementation of the model.
The code is not optimized yet: synchronized BatchNorm may be implemented in the future, so some fragments of the original 2D code still remain.

### ASPP

It could be useful to also implement [synchronized BatchNorm](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch), as the batch size gets very small when training 3D models.

In [None]:
# export
class ASPPConv(nn.Sequential):
    """Dilated 3x3x3 convolution branch of the ASPP module.

    Conv3d (no bias) -> norm_layer -> ReLU. The padding equals the dilation,
    so with a 3x3x3 kernel and stride 1 the spatial size is preserved.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        dilation: dilation rate of the convolution (also used as padding).
        norm_layer: callable mapping a channel count to a normalization module.
    """
    def __init__(self, in_channels, out_channels, dilation, norm_layer):
        super().__init__(
            nn.Conv3d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            norm_layer(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """Image-level pooling branch of ASPP.

    Global average pooling to 1x1x1, then a bias-free 1x1x1 conv, norm and
    ReLU. The result is broadcast back to the input's spatial size with
    trilinear interpolation.
    """
    def __init__(self, in_channels, out_channels, norm_layer):
        layers = (
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        # remember the last three (spatial) dims before pooling collapses them
        target_size = x.shape[-3:]
        pooled = super().forward(x)
        return F.interpolate(pooled, size=target_size, mode='trilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling for 3D feature maps.

    Several branches run in parallel on the same input:
    - a 1x1x1 conv;
    - one dilated 3x3x3 conv (`ASPPConv`) per rate in `dilations`;
    - an image-level pooling branch (`ASPPPooling`).

    Their outputs are concatenated along dim 1. A projection of 1x1x1 conv,
    norm, ReLU and Dropout(0.5) then maps them back to `out_channels`.
    """
    def __init__(self, in_channels, dilations, out_channels=256, norm_layer=nn.BatchNorm3d):
        super().__init__()
        pointwise = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(),
        )
        dilated = [ASPPConv(in_channels, out_channels, rate, norm_layer) for rate in tuple(dilations)]
        pooling = ASPPPooling(in_channels, out_channels, norm_layer)
        self.convs = nn.ModuleList([pointwise, *dilated, pooling])

        n_branches = len(self.convs)
        self.project = nn.Sequential(
            nn.Conv3d(n_branches * out_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))

In [None]:
aspp_256 = ASPP(in_channels=2048, dilations=[1, 6, 12, 18], out_channels=256, norm_layer=nn.BatchNorm3d)
aspp_256(torch.randn(10, 2048, 1, 3, 3)).size()

torch.Size([10, 256, 1, 3, 3])

In [None]:
aspp_512 = ASPP(in_channels=2048, dilations=[1, 12, 24, 36], out_channels=512, norm_layer=nn.BatchNorm3d)
aspp_512(torch.randn(10, 2048, 1, 3, 3)).size()

torch.Size([10, 512, 1, 3, 3])

In [None]:
# export
class Decoder(nn.Module):
    """DeepLabV3+ decoder head for 3D feature maps.

    Steps:
    - The low-level backbone features are reduced to 48 channels (1x1x1 conv,
      norm, ReLU).
    - The ASPP output `x` is upsampled trilinearly to the spatial size of
      those features.
    - The two are concatenated along dim 1.
    - The result is refined by two 3x3x3 conv/norm/ReLU/Dropout blocks and
      a final 1x1x1 classifier conv.

    Args:
        num_classes: number of output channels of the final 1x1x1 conv.
        low_level_inplanes: number of channels of `low_level_feat`.
        norm_layer: callable mapping a channel count to a normalization module.
        aspp_channels: number of channels of the ASPP output `x`. Defaults to
            256, which matches `ASPP`'s default `out_channels`.
    """

    def __init__(self, num_classes, low_level_inplanes, norm_layer=nn.BatchNorm3d, aspp_channels=256):
        super().__init__()
        reduced_channels = 48  # width the low-level features are squeezed to

        self.conv1 = nn.Conv3d(low_level_inplanes, reduced_channels, 1, bias=False)
        self.bn1 = norm_layer(reduced_channels)
        self.relu = nn.ReLU()
        # Input width is derived from both branches instead of hard-coding 304 (= 256 + 48).
        self.last_conv = nn.Sequential(nn.Conv3d(aspp_channels + reduced_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_layer(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Conv3d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_layer(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv3d(256, num_classes, kernel_size=1, stride=1))
        self._init_weight()

    def forward(self, x, low_level_feat):
        """Fuse the ASPP output `x` with `low_level_feat`.

        Returns raw (pre-activation) class scores with `num_classes` channels
        at the spatial size of `low_level_feat`.
        """
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='trilinear', align_corners=True)
        x = torch.cat((x, low_level_feat), dim=1)
        return self.last_conv(x)

    def _init_weight(self):
        """Kaiming-normal init for Conv3d weights; weight=1, bias=0 for affine BatchNorm3d."""
        # TODO: also initialise a synchronized BatchNorm class here if one is added.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d) and m.affine:
                # non-affine norms have weight/bias set to None; nothing to initialise
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

In [None]:
Decoder(4, 256, nn.BatchNorm3d)

Decoder(
  (conv1): Conv3d(256, 48, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (last_conv): Sequential(
    (0): Conv3d(304, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (5): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Conv3d(256, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
)

In [None]:
# export
class DeepLab(nn.Module):
    """3D DeepLabV3+ segmentation network.

    Pipeline:
    - backbone;
    - ASPP on the deepest feature map;
    - Decoder fusing that output with a low-level feature map;
    - softmax over the channel dimension;
    - trilinear upsampling to the spatial size of the input.

    Args:
        backbone: backbone constructor passed to `build_backbone`
            (default `resnet101_3d`).
        output_stride: forwarded to `build_backbone`.
        num_classes: number of output channels.
        n_channels: number of input channels, forwarded to `build_backbone`.
        norm_layer: callable mapping a channel count to a normalization module.
        aspp_dilations: dilation rates of the dilated ASPP branches.
        high_level_channels: channels of the deepest backbone feature map fed
            to ASPP (2048 for the default backbone).
        low_level_channels: channels of the low-level backbone feature map fed
            to the decoder (256 for the default backbone).
    """

    def __init__(self, backbone=resnet101_3d, output_stride=16, num_classes=4, n_channels=3,
                 norm_layer=nn.BatchNorm3d, aspp_dilations=(1, 12, 24, 36),
                 high_level_channels=2048, low_level_channels=256):
        super().__init__()
        # Honour the requested backbone (defaults to resnet101_3d).
        self.backbone = build_backbone(backbone, output_stride, norm_layer, n_channels)
        # NOTE(review): the 2D DeepLabV3 reference pairs output_stride=16 with
        # dilations (1, 6, 12, 18); (1, 12, 24, 36) is kept as the default for
        # backward compatibility — confirm which is intended.
        self.aspp = ASPP(in_channels=high_level_channels, dilations=aspp_dilations,
                         out_channels=256, norm_layer=norm_layer)
        self.decoder = Decoder(num_classes, low_level_channels, norm_layer)
        self.act = nn.Softmax(dim=1)

    def forward(self, input):
        # Assumes the backbone returns 5 feature maps, index 1 being the
        # low-level features and index 4 the deepest — TODO confirm.
        _, low_level_feat, _, _, x = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = self.act(x)
        # Trilinear weights are non-negative and sum to 1, so the upsampled
        # probabilities still sum to 1 over the class dimension.
        x = F.interpolate(x, size=input.size()[2:], mode='trilinear', align_corners=True)

        return x

In [None]:
model = DeepLab(n_channels=4, num_classes=4, norm_layer=IdentityLayer)
x = torch.rand(2, 4, 3, 30, 30)  # named `x` to avoid shadowing the builtin `input`
output = model(x)
print(output.size())

torch.Size([2, 4, 3, 30, 30])


In [None]:
# hide
from nbdev.export import notebook2script
notebook2script()

Converted 01_basics.ipynb.
Converted 02_transforms.ipynb.
Converted 03_datablock.ipynb.
Converted 04_datasets.ipynb.
Converted 05a_models.modules.ipynb.
Converted 05b_models.alexnet.ipynb.
Converted 05b_models.deeplabv3.ipynb.
Converted 05b_models.densenet.ipynb.
Converted 05b_models.resnet.ipynb.
Converted 05c_models.siamese.ipynb.
Converted 05c_models.unet.ipynb.
Converted 05d_models.losses.ipynb.
Converted 06_callback.ipynb.
Converted 99_tools.ipynb.
Converted index.ipynb.
