# 3D UNet


In [None]:
# hide
# Append the parent directory to the import path, presumably so the local
# `faimed3d` package can be imported when this notebook runs from its own folder.
import sys
sys.path.append("..")

In [None]:
# default_exp models.unet
# export 
from fastai.basics import *
from fastai.layers import *
import torchvision, torch
from warnings import warn
from faimed3d.models.modules import Sequential_

## Building Blocks

### Convolutional Blocks

#### DoubleConv
Standard double convolutional block of the UNet, adapted for 3D.

    1. 3D convolution with a 3x3x3 kernel
    2. 3D batch normalization
    3. ReLU
    4. Repeat steps 1-3 once

In [None]:
# export
class DoubleConv(nn.Module):
    """Two stacked 3D convolution units: (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) x 2.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (None, 0) falls back to `out_channels`.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = []
        # one conv/BN/ReLU unit per (input width, output width) pair
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm3d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = Sequential_(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU units to `x`."""
        out = self.double_conv(x)
        return out


#### ResBlock

The double conv block, replaced by a single ResNet-like bottleneck block (1x1x1 -> 3x3x3 -> 1x1x1 convolutions). Note that `forward` does not add an identity shortcut.

In [None]:
# export
class ResBlock(nn.Module):
    """Bottleneck-style 3D block: 1x1x1 reduce -> 3x3x3 -> 1x1x1 expand.

    Each convolution is bias-free and followed by BatchNorm3d. The first two
    BatchNorms are followed by ReLU; the last one is not.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        mid_channels: bottleneck width; any falsy value (None, 0) falls back
            to ``int(out_channels / 2)``.

    NOTE(review): despite the name, `forward` adds no identity/shortcut
    connection. The output is just the conv stack applied to `x`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels:
            width = mid_channels
        else:
            width = int(out_channels / 2)

        def bn(n_features):
            return nn.BatchNorm3d(n_features, eps=1e-05, momentum=0.1,
                                  affine=True, track_running_stats=True)

        self.res_block = Sequential_(
            # reduce channels to the bottleneck width
            nn.Conv3d(in_channels, width, kernel_size=1, stride=1, bias=False),
            bn(width),
            nn.ReLU(inplace=True),
            # spatial 3x3x3 convolution at bottleneck width (size preserved)
            nn.Conv3d(width, width, kernel_size=3, stride=1, padding=1, bias=False),
            bn(width),
            nn.ReLU(inplace=True),
            # expand to the output width; no activation afterwards
            nn.Conv3d(width, out_channels, kernel_size=1, stride=1, bias=False),
            bn(out_channels),
        )

    def forward(self, x):
        """Run the bottleneck conv stack on `x`."""
        y = self.res_block(x)
        return y

#### Double ResBlock

Double ResBlock with additional grouped convolutions in the middle layers.

In [None]:
# export 
class DoubleResBlock(nn.Module):
    """Two stacked bottleneck-style 3D units with grouped convolutions.

    Convolution sequence (every conv is bias-free and followed by BatchNorm3d;
    every BatchNorm except the last is followed by ReLU):

      1st unit: 1x1x1 in->mid, 3x3x3 mid->mid (groups=4), 1x1x1 mid->mid
      2nd unit: 1x1x1 mid->out (groups=16), 3x3x3 out->out (groups=16),
                1x1x1 out->out

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        mid_channels: width of the first unit; any falsy value (None, 0)
            falls back to ``out_channels // 2``.

    The grouped convolutions require `mid_channels` to be divisible by 16
    and `out_channels` by 16; Conv3d rejects other values.

    NOTE(review): `forward` adds no identity/shortcut connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels if mid_channels else out_channels // 2

        # (c_in, c_out, kernel_size, groups) for each conv, in forward order
        specs = [
            (in_channels, mid, 1, 1),
            (mid, mid, 3, 4),
            (mid, mid, 1, 1),
            (mid, out_channels, 1, 16),
            (out_channels, out_channels, 3, 16),
            (out_channels, out_channels, 1, 1),
        ]
        last = len(specs) - 1
        layers = []
        for idx, (c_in, c_out, k, g) in enumerate(specs):
            # padding k // 2 keeps the spatial size (0 for 1x1x1, 1 for 3x3x3)
            layers.append(nn.Conv3d(c_in, c_out, kernel_size=k, stride=1,
                                    padding=k // 2, bias=False, groups=g))
            layers.append(nn.BatchNorm3d(c_out, eps=1e-05, momentum=0.1,
                                         affine=True, track_running_stats=True))
            if idx < last:
                layers.append(nn.ReLU(inplace=True))
        self.res_block = Sequential_(*layers)

    def forward(self, x):
        """Run both grouped bottleneck units on `x`."""
        y = self.res_block(x)
        return y

### Down Blocks

In [None]:
# export
class Down(nn.Module):
    """Encoder step: 2x2x2 max-pool (downsamples depth, height and width by 2),
    then a `DoubleConv` mapping `in_channels` -> `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=(2, 2, 2))
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = Sequential_(pool, conv)

    def forward(self, x):
        """Pool `x`, then apply the double convolution."""
        y = self.maxpool_conv(x)
        return y

class ResDown(nn.Module):
    """Encoder step: 2x2x2 max-pool (downsamples depth, height and width by 2),
    then a `ResBlock` mapping `in_channels` -> `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [nn.MaxPool3d(kernel_size=(2, 2, 2)), ResBlock(in_channels, out_channels)]
        self.maxpool_conv = Sequential_(*stages)

    def forward(self, x):
        """Pool `x`, then apply the ResBlock."""
        y = self.maxpool_conv(x)
        return y
    
class DoubleResDown(nn.Module):
    """Encoder step: max-pool with kernel (1, 2, 2), then a `DoubleResBlock`.

    Unlike `Down` and `ResDown`, the pooling kernel is 1 along the depth axis.
    Only height and width are downsampled; depth is left unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        in_plane_pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        block = DoubleResBlock(in_channels, out_channels)
        self.maxpool_conv = Sequential_(in_plane_pool, block)

    def forward(self, x):
        """Pool `x` in-plane, then apply the double res block."""
        y = self.maxpool_conv(x)
        return y

### Upscaling Blocks

In [None]:
# export
class Up(nn.Module):
    """Decoder step: upsample by 2, align with the skip tensor, concatenate, then `DoubleConv`.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor
            that is fed to DoubleConv.
        out_channels: channels produced by DoubleConv.
        trilinear: if True, use parameter-free trilinear upsampling (channel
            count unchanged) and give DoubleConv a middle width of
            in_channels // 2. Otherwise use a ConvTranspose3d that maps
            in_channels -> in_channels // 2.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        half = in_channels // 2
        if not trilinear:
            self.up = nn.ConvTranspose3d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample `x1`, pad/crop it to the spatial size of skip tensor `x2`,
        then return `conv(cat([x2, x1], dim=1))`."""
        x1 = self.up(x1)
        # 5D tensors (N, C, D, H, W): compare the three spatial dims
        dz, dy, dx = (x2.size(i) - x1.size(i) for i in (2, 3, 4))
        # F.pad takes (before, after) pairs starting from the LAST dim;
        # negative values crop instead of pad
        x1 = F.pad(x1, [dx // 2, dx - dx // 2,
                        dy // 2, dy - dy // 2,
                        dz // 2, dz - dz // 2])
        # for background on padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)

class ResUp(nn.Module):
    """Decoder step: upsample by 2, align with the skip tensor, concatenate, then `ResBlock`.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor
            that is fed to ResBlock.
        out_channels: channels produced by ResBlock.
        trilinear: if True, use parameter-free trilinear upsampling (channel
            count unchanged) and give ResBlock a bottleneck width of
            in_channels // 2. Otherwise use a ConvTranspose3d that maps
            in_channels -> in_channels // 2.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        if trilinear:
            upsampler = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            block = ResBlock(in_channels, out_channels, in_channels // 2)
        else:
            upsampler = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            block = ResBlock(in_channels, out_channels)
        self.up = upsampler
        self.conv = block

    def forward(self, x1, x2):
        """Upsample `x1`, pad/crop it to the spatial size of skip tensor `x2`,
        then return `conv(cat([x2, x1], dim=1))`."""
        up = self.up(x1)
        # 5D tensors (N, C, D, H, W): dims 2..4 are spatial
        diff_z = x2.size()[2] - up.size()[2]
        diff_y = x2.size()[3] - up.size()[3]
        diff_x = x2.size()[4] - up.size()[4]
        # pad pairs are given last-dim first; negative amounts crop
        pads = []
        for d in (diff_x, diff_y, diff_z):
            pads += [d // 2, d - d // 2]
        up = F.pad(up, pads)
        # for background on padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, up], dim=1))
    
class DoubleResUp(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, then `DoubleResBlock`.

    Its encoder counterpart `DoubleResDown` pools with kernel (1, 2, 2) and so
    never reduces depth. This block therefore upsamples with `scale_factor`
    (default (1, 2, 2)) to mirror it. Upsampling depth by 2 as well doubled
    the decoder depth at every level. The negative padding in `forward` then
    cropped it back, so decoder features no longer lined up with the skip
    tensor along the depth axis.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor
            that is fed to DoubleResBlock.
        out_channels: channels produced by DoubleResBlock.
        trilinear: if True, use parameter-free trilinear upsampling (channel
            count unchanged) and give DoubleResBlock a middle width of
            in_channels // 2. Otherwise use a ConvTranspose3d that maps
            in_channels -> in_channels // 2.
        scale_factor: per-axis (D, H, W) upsampling factor, either an int or a
            3-tuple of ints. It is also used as the kernel size and stride of
            the transposed convolution. Pass 2 for the previous isotropic
            behavior.
    """

    def __init__(self, in_channels, out_channels, trilinear=True, scale_factor=(1, 2, 2)):
        super().__init__()

        if trilinear:
            # parameter-free; keeps the channel count, so the block reduces it
            self.up = nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True)
            self.conv = DoubleResBlock(in_channels, out_channels, in_channels // 2)
        else:
            # learned upsampling that also halves the channel count
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2,
                                         kernel_size=scale_factor, stride=scale_factor)
            self.conv = DoubleResBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample `x1`, pad/crop it to the spatial size of skip tensor `x2`,
        then return `conv(cat([x2, x1], dim=1))`."""
        x1 = self.up(x1)
        # 5D tensors (N, C, D, H, W): dims 2..4 are spatial
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]

        # F.pad takes (before, after) pairs starting from the LAST dim;
        # negative values crop instead of pad
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        # for background on padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)

        return self.conv(x)

### Out Layers

In [None]:
# export
class OutConv(nn.Module):
    """Output head: a single 1x1x1 Conv3d (with bias) mapping `in_channels`
    feature maps to `out_channels` outputs, spatial size unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project the channels of `x` with the 1x1x1 convolution."""
        projection = self.conv
        return projection(x)
    
class OutDuoubleRes(nn.Module):
    """Output head made of two stacked bottleneck-style 3D units.

    Convolution sequence (all bias-free, each followed by BatchNorm3d, and
    ReLU after every BatchNorm except the last):

      1x1x1 in->mid, 3x3x3 mid->mid, 1x1x1 mid->mid,
      1x1x1 mid->out, 3x3x3 out->out, 1x1x1 out->out

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        mid_channels: width of the first unit; defaults to in_channels // 4
            only when it is None.

    Note: the final layer is a BatchNorm3d, so the head's output is a
    batch-normalized activation rather than a raw convolution output.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = in_channels // 4 if mid_channels is None else mid_channels

        # (c_in, c_out, kernel_size) for each conv, in forward order
        plan = (
            (in_channels, mid, 1),
            (mid, mid, 3),
            (mid, mid, 1),
            (mid, out_channels, 1),
            (out_channels, out_channels, 3),
            (out_channels, out_channels, 1),
        )
        modules = []
        for n, (c_in, c_out, k) in enumerate(plan, start=1):
            modules += [
                # padding k // 2 keeps the spatial size (0 for 1x1x1, 1 for 3x3x3)
                nn.Conv3d(c_in, c_out, kernel_size=k, stride=1, padding=k // 2, bias=False),
                nn.BatchNorm3d(c_out, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            ]
            if n != len(plan):
                modules.append(nn.ReLU(inplace=True))
        self.conv = Sequential_(*modules)

    def forward(self, x):
        """Run the two bottleneck units on `x`."""
        y = self.conv(x)
        return y

## Models

### Base Classes

UNet architecture as initially described by [Ronneberger et al](https://arxiv.org/abs/1505.04597).

In [None]:
# export
class AbstractUNet3D(nn.Module):
    """Shared forward pass for the 3D UNets in this module.

    Subclasses must define the submodules `inc`, `down1`..`down4`,
    `up1`..`up4` (each called as `up(x, skip)`) and `outc`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Encoder (inc, down1..down4), then decoder (up1..up4) fed with the
        encoder outputs as skips in reverse order, then `outc`."""
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])
        # deepest skip pairs with the first up stage
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = up(out, skip)
        return self.outc(out)


Double UNet as described by [Jha et al.](https://arxiv.org/pdf/2006.04868.pdf), adapted for 3D.  
The original architecture as published: 

![Double UNet Architecture](https://raw.githubusercontent.com/DebeshJha/2020-CBMS-DoubleU-Net/master/img/DoubleU-Net.png)

Some parts have to be adapted for the architecture to work on 3D images.

The paper uses a pretrained VGG-19 encoder, which cannot be used here because no pretrained 3D version exists. All parts are therefore built from scratch.  
As a consequence, Network 1 has the same architecture as Network 2.

In [None]:

class AbstractDoubleUNet3D(nn.Module):
    """Abstract base class for 3D Double UNets (two networks run in sequence).

    Subclasses must:
      * implement `Network_1`, `multiply`, `Network_2` and `concat`;
      * set `self.n_channels`, which `forward` passes to `multiply`.

    `Network` is a helper that runs one full UNet pass using the submodules
    `inc`, `down1`..`down4`, `aspp`, `up1`..`up4` and `outc`, which the
    subclass must provide if it uses that helper.
    """
    def __init__(self):
        super(AbstractDoubleUNet3D, self).__init__()

    def Network(self, x):
        """Run one UNet pass with an `aspp` module at the bottleneck.

        Returns:
            tuple (out, x1, x2, x3, x4): the output of `outc` followed by the
            outputs of `inc`, `down1`, `down2` and `down3`.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x5 = self.aspp(x5)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return (x, x1, x2, x3, x4)

    def multiply(self, x, n_channels):
        """Transform the output of network 1 into the input of network 2."""
        raise NotImplementedError

    def Network_1(self, x):
        """First network; must return a 5-tuple (output, x1a, x2a, x3a, x4a)."""
        raise NotImplementedError

    def Network_2(self, x, x1a, x2a, x3a, x4a):
        """Second network; receives the four extra tensors returned by Network_1."""
        raise NotImplementedError

    def concat(self, x1, x2):
        """Combine the outputs of both networks into the final output."""
        raise NotImplementedError

    def forward(self, x):
        """Run Network_1, transform its output with `multiply`, run Network_2,
        and combine both outputs with `concat`."""
        out_1, x1a, x2a, x3a, x4a = self.Network_1(x)
        # `n_channels` was previously an undefined global (NameError on every
        # call); it is now read from the instance and must be set by subclasses.
        # NOTE(review): the DoubleU-Net paper multiplies the *input* by network
        # 1's output, but only out_1 is passed here and the original input is
        # discarded. Confirm the intent.
        x = self.multiply(out_1, self.n_channels)
        out_2 = self.Network_2(x, x1a, x2a, x3a, x4a)
        return self.concat(out_1, out_2)
    
        

        

### 3D UNet

3D version of the original UNet

In [None]:
# export
class UNet3D(AbstractUNet3D):
    """3D version of the original UNet.

    A DoubleConv stem, four `Down` stages (widths 64-128-256-512-1024), four
    `Up` stages and a 1x1x1 `OutConv` head. With trilinear=True the output
    widths of `down4` and `up1`..`up3` are halved, because trilinear
    upsampling does not reduce the channel count.

    Args:
        n_channels: channels of the input volume.
        n_classes: output channels of the final 1x1x1 convolution.
        trilinear: use trilinear upsampling instead of transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, trilinear=False):
        super().__init__()
        f = 2 if trilinear else 1

        self.inc = DoubleConv(n_channels, 64)
        self.down1, self.down2, self.down3 = tuple(Down(c, 2 * c) for c in (64, 128, 256))
        self.down4 = Down(512, 1024 // f)
        self.up1, self.up2, self.up3 = tuple(Up(c, c // 2 // f, trilinear) for c in (1024, 512, 256))
        self.up4 = Up(128, 64, trilinear)
        self.outc = OutConv(64, n_classes)
       

In [None]:
# export
class UResNet3D(AbstractUNet3D):
    """3D UNet variant built from `DoubleResBlock`-based stages.

    A DoubleResBlock stem, four `DoubleResDown` stages (widths
    64-128-256-512-1024), four `DoubleResUp` stages and an `OutDuoubleRes`
    head. With trilinear=True the output widths of `down4` and `up1`..`up3`
    are halved.

    Args:
        n_channels: channels of the input volume.
        n_classes: channels produced by the output head.
        trilinear: use trilinear upsampling instead of transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, trilinear=True):
        super().__init__()
        shrink = 2 if trilinear else 1

        self.inc = DoubleResBlock(n_channels, 64)
        self.down1, self.down2, self.down3 = [DoubleResDown(w, w * 2) for w in (64, 128, 256)]
        self.down4 = DoubleResDown(512, 1024 // shrink)
        self.up1, self.up2, self.up3 = [DoubleResUp(w, w // 2 // shrink, trilinear)
                                        for w in (1024, 512, 256)]
        self.up4 = DoubleResUp(128, 64, trilinear)
        self.outc = OutDuoubleRes(64, n_classes)
    

In [None]:
# hide
# Convert the project's notebooks into Python modules with nbdev; the
# "Converted ..." lines below are the output of this call.
from nbdev.export import *
notebook2script()

Converted 01-basics.ipynb.
Converted 02-transforms.ipynb.
Converted 03-datablock.ipynb.
Converted 04-datasets.ipynb.
Converted 05-models-all.ipynb.
Converted 05a-models-modules.ipynb.
Converted 05b-models-unet.ipynb.
Converted 05c-models-losses.ipynb.
Converted 06-callbacks.ipynb.
Converted 06-various-tools.ipynb.
