# 3D UNet


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.unet
# export 
from fastai.basics import *
from fastai.layers import *
import torchvision, torch
from warnings import warn
from faimed3d.models.resnet import *

## Building Blocks


### Convolutional Blocks
#### DoubleConv
Standard double convolutional block for UNet, adapted for 3D.


In [None]:
# export
class DoubleConv(nn.Module):
    """Two consecutive `Conv3d -> BatchNorm3d -> ReLU` stages.

    Both convolutions use a kernel size of 3 with padding 1, so only the
    channel dimension differs between input and output.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; a falsy value
            (`None` or 0) falls back to `out_channels`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        stages = []
        for n_in, n_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend((
                nn.Conv3d(n_in, n_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(n_out),
                nn.ReLU(inplace=True),
            ))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to `x`."""
        return self.double_conv(x)


In [None]:
# Smoke test: channels go from 32 to 64, depth/height/width stay unchanged.
DoubleConv(32, 64)(torch.randn(2, 32, 10, 50, 50)).size()

torch.Size([2, 64, 10, 50, 50])

In [None]:
# export 
class DoubleResBlock(nn.Module):
    """Two stacked `1x1x1 -> 3x3x3 -> 1x1x1` convolution groups.

    Every convolution is bias-free and followed by `BatchNorm3d`; every batch
    norm except the last one is followed by a ReLU. The layers are applied
    strictly in sequence: `forward` does not add a skip connection.

    Some convolutions are grouped (4 or 16 groups), so the intermediate width
    and `out_channels` must both be divisible by 16.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of the result.
        mid_channels: width of the first group; a falsy value (`None` or 0)
            falls back to `out_channels // 2`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels // 2

        # (input channels, output channels, kernel size, groups), in order.
        plan = (
            # 1st block
            (in_channels, hidden, 1, 1),
            (hidden, hidden, 3, 4),
            (hidden, hidden, 1, 1),
            # 2nd block
            (hidden, out_channels, 1, 16),
            (out_channels, out_channels, 3, 16),
            (out_channels, out_channels, 1, 1),
        )

        layers = []
        for n_in, n_out, ksize, n_groups in plan:
            # padding of ksize // 2 keeps the spatial size (0 for 1x1x1, 1 for 3x3x3)
            layers.append(nn.Conv3d(n_in, n_out, kernel_size=ksize, stride=1,
                                    padding=ksize // 2, groups=n_groups, bias=False))
            layers.append(nn.BatchNorm3d(n_out))
            layers.append(nn.ReLU(inplace=True))
        layers.pop()  # the final batch norm is not followed by an activation

        self.res_block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer sequence to `x`."""
        return self.res_block(x)

### Downscaling Block

In [None]:
# export
class ResDown(nn.Module):
    """Max-pool with a 2x2x2 window, then apply a `ResBlock`.

    Args:
        in_channels: first argument handed to `ResBlock`.
        out_channels: second argument handed to `ResBlock`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=(2, 2, 2))
        # NOTE(review): `ResBlock` is not defined in this notebook; it presumably
        # comes from one of the star imports -- confirm that it accepts
        # (in_channels, out_channels) positionally.
        block = ResBlock(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, block)

    def forward(self, x):
        """Pool `x`, then run the block on the result."""
        return self.maxpool_conv(x)
    
class DoubleResDown(nn.Module):
    """Max-pool over height and width only, then apply a `DoubleResBlock`.

    The pooling window is (1, 2, 2), so the depth axis is not pooled.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the `DoubleResBlock`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        block = DoubleResBlock(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, block)

    def forward(self, x):
        """Pool `x`, then run the block on the result."""
        return self.maxpool_conv(x)

### Upscaling Block

In [None]:
# export
class Up(nn.Module):
    """Upscale a feature map, merge it with a skip connection, then `DoubleConv`.

    Args:
        in_channels: channels of the concatenated tensor fed to `self.conv`.
        out_channels: channels of the result.
        trilinear: if true, upscale with trilinear interpolation followed by a
            `DoubleConv` that halves the channels; otherwise use a transposed
            convolution that halves the channels.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        half = in_channels // 2

        if not trilinear:
            self.up = nn.ConvTranspose3d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation has no weights, so a regular convolution block is
            # used to reduce the number of channels.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
                DoubleConv(in_channels, half),
            )
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upscale `x1`, fit it to the size of `x2` and fuse both.

        `x1` is padded (or cropped, when the difference is negative) along
        axes 2, 3 and 4 so that it matches `x2` before concatenation on axis 1.
        """
        x1 = self.up(x1)

        # F.pad takes (before, after) pairs starting from the LAST axis.
        pads = []
        for axis in (4, 3, 2):
            gap = x2.size(axis) - x1.size(axis)
            pads += [gap // 2, gap - gap // 2]
        x1 = F.pad(x1, pads)

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
    
class ResUp(Up):
    """Upscaling followed by a `ResBlock`; `forward` is inherited from `Up`.

    Args:
        in_channels: channels of the concatenated tensor fed to `self.conv`.
        out_channels: second argument handed to `ResBlock`.
        trilinear: if true, upscale with plain trilinear interpolation (the
            channel count is left untouched); otherwise use a transposed
            convolution that halves the channels.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        # Initialise only the `nn.Module` machinery. `Up.__init__` would build
        # a full set of `DoubleConv`/`ConvTranspose3d` layers for `self.up` and
        # `self.conv` that are replaced right below, wasting time and memory.
        nn.Module.__init__(self)

        # NOTE(review): `ResBlock` is not defined in this notebook; it presumably
        # comes from one of the star imports -- confirm its signature.
        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = ResBlock(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResBlock(in_channels, out_channels)
    
class DoubleResUp(Up):
    """Upscaling followed by a `DoubleResBlock`; `forward` is inherited from `Up`.

    Args:
        in_channels: channels of the concatenated tensor fed to `self.conv`.
        out_channels: channels produced by the `DoubleResBlock`.
        trilinear: if true, upscale with plain trilinear interpolation (the
            channel count is left untouched) and use `in_channels // 2` as
            intermediate width of the block; otherwise use a transposed
            convolution that halves the channels.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        # Initialise only the `nn.Module` machinery. `Up.__init__` would build
        # a full set of `DoubleConv`/`ConvTranspose3d` layers for `self.up` and
        # `self.conv` that are replaced right below, wasting time and memory.
        nn.Module.__init__(self)

        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleResBlock(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleResBlock(in_channels, out_channels)


### Out Layers

In [None]:
# export
class OutConv(nn.Module):
    """Final 1x1x1 convolution mapping `in_channels` to `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv = projection

    def forward(self, x):
        """Project `x` onto the output channels."""
        return self.conv(x)
    
class OutDuoubleRes(nn.Module):
    """Output head made of two `1x1x1 -> 3x3x3 -> 1x1x1` convolution groups.

    Every convolution is bias-free and followed by `BatchNorm3d`; every batch
    norm except the last one is followed by a ReLU, so the head ends with a
    batch norm.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of the result.
        mid_channels: width of the first group; `None` selects
            `in_channels // 4`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = in_channels // 4 if mid_channels is None else mid_channels

        # (input channels, output channels, kernel size), in order.
        plan = (
            # 1st block
            (in_channels, hidden, 1),
            (hidden, hidden, 3),
            (hidden, hidden, 1),
            # 2nd block
            (hidden, out_channels, 1),
            (out_channels, out_channels, 3),
            (out_channels, out_channels, 1),
        )

        layers = []
        for n_in, n_out, ksize in plan:
            # padding of ksize // 2 keeps the spatial size (0 for 1x1x1, 1 for 3x3x3)
            layers.append(nn.Conv3d(n_in, n_out, kernel_size=ksize, stride=1,
                                    padding=ksize // 2, bias=False))
            layers.append(nn.BatchNorm3d(n_out))
            layers.append(nn.ReLU(inplace=True))
        layers.pop()  # no activation after the final batch norm

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer sequence to `x`."""
        return self.conv(x)

## Models

### Base Classes

UNet architecture as initially described by [Ronneberger et al.](https://arxiv.org/abs/1505.04597).

In [None]:
# export
class AbstractUNet3D(nn.Module):
    """Shared decoder wiring for the 3D UNets; meant to be subclassed.

    Subclasses are expected to provide `encoder` (returning five feature maps,
    shallowest first), the four up blocks `up1` to `up4` and the head `outc`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Encode `input`, walk up the decoder and return the head's output."""
        x1, x2, x3, x4, x = self.encoder(input)
        # The deepest feature map is fused with the skips from deep to shallow.
        for level, skip in enumerate((x4, x3, x2, x1), start=1):
            x = getattr(self, f"up{level}")(x, skip)
        return self.outc(x)

### 3D UResNet

3D version of the original UNet with ResNet3D Encoder

In [None]:
# export
class UResNet3D(AbstractUNet3D):
    """UNet with a backbone encoder and a decoder made of `DoubleResUp` blocks.

    Args:
        backbone: handed to `build_backbone` to create the encoder.
        n_channels: number of input channels, forwarded to `build_backbone`.
        n_classes: number of output channels of the head.
        trilinear: forwarded to every `DoubleResUp`.
    """

    def __init__(self, backbone, n_channels, n_classes, trilinear=False):
        super().__init__()
        self.encoder = build_backbone(backbone, n_channels=n_channels,
                                      output_stride=16, BatchNorm=nn.BatchNorm3d)

        # Encoders built from `BasicBlock3d` are treated as ending with 512
        # channels, every other encoder as ending with 2048 channels.
        width = 512 if isinstance(self.encoder.layer1[0], BasicBlock3d) else 2048

        # NOTE(review): with trilinear=True `DoubleResUp` upsamples without
        # reducing channels, so the concatenated width may not match what the
        # block expects -- confirm before enabling it with this encoder.
        for level in range(1, 5):
            # each decoder stage halves the channel count
            setattr(self, f"up{level}", DoubleResUp(width, width // 2, trilinear))
            width //= 2

        self.outc = OutDuoubleRes(width, n_classes)
        

class UResNet3D_2(AbstractUNet3D):
    """UNet whose encoder and decoder are both built from `DoubleResBlock`s.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels of the head.
        trilinear: forwarded to every `DoubleResUp`; when true, the deepest
            encoder stage and the first three decoder stages use half the
            width.
    """

    def __init__(self, n_channels, n_classes, trilinear=True):
        super().__init__()
        shrink = 2 if trilinear else 1

        # Encoder: stem plus four downscaling stages.
        self.inc = DoubleResBlock(n_channels, 64)
        down_widths = (64, 128, 256, 512, 1024 // shrink)
        for idx, (n_in, n_out) in enumerate(zip(down_widths, down_widths[1:]), start=1):
            setattr(self, f"down{idx}", DoubleResDown(n_in, n_out))

        # Decoder: four upscaling stages, then the output head.
        up_plan = (
            (1024, 512 // shrink),
            (512, 256 // shrink),
            (256, 128 // shrink),
            (128, 64),
        )
        for idx, (n_in, n_out) in enumerate(up_plan, start=1):
            setattr(self, f"up{idx}", DoubleResUp(n_in, n_out, trilinear))
        self.outc = OutDuoubleRes(64, n_classes)

    def encoder(self, x):
        """Return the five encoder feature maps, shallowest first."""
        feats = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            feats.append(stage(feats[-1]))
        return tuple(feats)

In [None]:
# Smoke test with the `resnet18_3d` backbone: the output has `n_classes` channels and the input's spatial size.
UResNet3D(resnet18_3d, n_channels=5, n_classes=2, trilinear=False)(torch.randn(5, 5, 5, 50, 50)).size()

torch.Size([5, 2, 5, 50, 50])

### Double UNet 
as described by [Jha et al.](https://arxiv.org/pdf/2006.04868.pdf), adapted for 3D.  
The original architecture as published: 

![Double UNet Architecture](https://raw.githubusercontent.com/DebeshJha/2020-CBMS-DoubleU-Net/master/img/DoubleU-Net.png)

ASPP is not yet implemented

In [None]:
# export
class DoubleUNet3D(nn.Module):
    """Two UNets in sequence, the second one working on the masked input.

    The first network produces `out1`. Its mean over the class axis is
    multiplied with the input and fed to the second network, whose skip
    connections are the element-wise product of both encoders' feature maps.
    Both outputs are concatenated and reduced to `n_classes` channels.

    Args:
        backbone: handed to `build_backbone` to create both encoders.
        n_channels: number of input channels.
        n_classes: number of output channels.
        trilinear: forwarded to every `Up` block.
    """

    def __init__(self, backbone, n_channels, n_classes, trilinear=False):
        # Initialise the module first so attribute assignment always goes
        # through a fully set-up `nn.Module`.
        super().__init__()
        store_attr()
        self.encoder1 = build_backbone(backbone, n_channels=self.n_channels,
                                       output_stride=16, BatchNorm=nn.BatchNorm3d)
        self.encoder2 = build_backbone(backbone, n_channels=self.n_channels,
                                       output_stride=16, BatchNorm=nn.BatchNorm3d)

        # Encoders built from `BasicBlock3d` are treated as ending with 512
        # channels, every other encoder as ending with 2048 channels.
        in_channels = 512 if isinstance(self.encoder1.layer1[0], BasicBlock3d) else 2048

        # Suffix `a` belongs to the first network, `b` to the second one.
        self.up1a = Up(in_channels, in_channels // 2, self.trilinear)
        self.up1b = Up(in_channels, in_channels // 2, self.trilinear)
        in_channels = in_channels // 2
        self.up2a = Up(in_channels, in_channels // 2, self.trilinear)
        self.up2b = Up(in_channels, in_channels // 2, self.trilinear)
        in_channels = in_channels // 2
        self.up3a = Up(in_channels, in_channels // 2, self.trilinear)
        self.up3b = Up(in_channels, in_channels // 2, self.trilinear)
        in_channels = in_channels // 2
        self.up4a = Up(in_channels, in_channels // 2, self.trilinear)
        self.up4b = Up(in_channels, in_channels // 2, self.trilinear)
        self.outc1 = OutConv(in_channels // 2, self.n_classes)
        self.outc2 = OutConv(in_channels // 2, self.n_classes)
        self.outc3 = OutConv(self.n_classes * 2, self.n_classes)

    def forward(self, input):
        """Return the fused prediction of both networks for `input`."""
        # First network.
        x1a, x2a, x3a, x4a, x5a = self.encoder1(input)
        x = self.up1a(x5a, x4a)
        x = self.up2a(x, x3a)
        x = self.up3a(x, x2a)
        x = self.up4a(x, x1a)
        # The first network has its own head; using `outc2` here would leave
        # `outc1` without any gradient.
        out1 = self.outc1(x)

        # Mean over the class axis, broadcast over all input channels.
        masked = torch.mean(out1, 1, keepdim=True) * input

        # Second network, skips gated by the first encoder's feature maps.
        x1b, x2b, x3b, x4b, x5b = self.encoder2(masked)
        x = self.up1b(x5b, x4a * x4b)
        x = self.up2b(x, x3a * x3b)
        x = self.up3b(x, x2a * x2b)
        x = self.up4b(x, x1a * x1b)
        out2 = self.outc2(x)

        logits = self.outc3(torch.cat((out1, out2), 1))
        return logits



In [None]:
# Smoke test with the `resnet18_3d` backbone: the output has `n_classes` channels and the input's spatial size.
DoubleUNet3D(resnet18_3d, n_channels=5, n_classes=2, trilinear=False)(torch.randn(5, 5, 10, 80, 80)).size()

torch.Size([5, 2, 10, 80, 80])

In [None]:
# hide
# Convert the project's notebooks into library modules (see the "Converted ..." lines in the output).
from nbdev.export import *
notebook2script()

Converted 01_basics.ipynb.
Converted 02_transforms.ipynb.
Converted 03_datablock.ipynb.
Converted 04_datasets.ipynb.
Converted 05a_models.modules.ipynb.
Converted 05b_models.alexnet.ipynb.
Converted 05b_models.deeplabv3.ipynb.
Converted 05b_models.resnet.ipynb.
Converted 05c_models.siamese.ipynb.
Converted 05c_models.unet.ipynb.
Converted 05d_models.losses.ipynb.
Converted 06_callback.ipynb.
Converted 99_tools.ipynb.
Converted index.ipynb.
