# Segmentation Models

Provided here are Keras implementations of deep learning-based segmentation models for biomedical image segmentation. In this notebook, I provide 3D versions of U-Net, U-ResNet, and U-DenseNet.

All U-Net style segmentation architectures consist of a downsampling path (the encoder), a bottleneck, and an upsampling path (the decoder). Within each of these components are sequences of convolution operations, which we package into blocks. The key insight that U-Net introduced was the idea of skip connections. Each layer in the encoder is linked to its corresponding layer in the decoder via a channel-wise concatenation operation. Below is an illustration of a generalized U-Net:

<img src ="images/architecture.png">

The main difference between the architectures provided here lies in the design of the block operators. A vanilla U-Net has the simplest design, with each block consisting of two consecutive convolution layers. U-ResNet uses the same two convolution layers but adds the first convolution layer's output to the second convolution layer's output. This addition layer is called a residual connection. Finally, the U-DenseNet blocks concatenate all of the previous outputs after each convolution operation. A pointwise convolution layer (a convolution whose kernel has side length 1) is added as the block's final step to reduce the number of feature maps accumulated by the concatenation layers. Below is a visual representation of each block design:

<img src ="images/blocks.png">

In [None]:
##### Tensorflow/Keras #####

# Not all of these are necessary here, but I always like to start with these imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
import keras.backend as K

### U-Net

In [None]:
def UNet(inputShape, numClasses):
    """Assemble a 3D U-Net and return it as an uncompiled Keras ``Model``.

    The network has four encoder levels, a bottleneck, and four decoder
    levels. Every level runs two consecutive 3x3x3 ReLU convolutions; the
    encoder output of each level is concatenated onto the upsampled decoder
    tensor of the matching level.

    Args:
        inputShape: Shape handed to ``layers.Input`` (batch axis excluded).
            Presumably (depth, height, width, channels) - TODO confirm.
        numClasses: Number of filters in the final 1x1x1 softmax convolution.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output tensor.

    NOTE(review): there are four 2x max-pooling steps, so each spatial
    dimension presumably has to be divisible by 16 for the skip
    concatenations to line up - confirm against the data pipeline.
    """
    # Keyword arguments shared by every regular / transposed convolution
    conv_kwargs = dict(kernel_size=(3, 3, 3), activation='relu', padding='same')
    up_kwargs = dict(kernel_size=(2, 2, 2), strides=(2, 2, 2), padding='same')

    # Feature-map counts: four encoder/decoder levels, then the bottleneck
    widths = [32, 64, 128, 256, 512]
    level_widths = widths[:-1]
    bottleneck_width = widths[-1]

    def double_conv(tensor, width):
        # Two back-to-back convolutions with the same number of filters
        for _ in range(2):
            tensor = layers.Conv3D(width, **conv_kwargs)(tensor)
        return tensor

    inputs = layers.Input(inputShape)
    x = inputs

    # Encoder: convolve, remember the result for the skip link, then pool
    skips = []
    for width in level_widths:
        x = double_conv(x, width)
        skips.append(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2))(x)

    # Bottleneck
    x = double_conv(x, bottleneck_width)

    # Decoder: upsample, merge with the matching encoder output, convolve.
    # The deepest skip tensor is consumed first.
    for width in reversed(level_widths):
        x = layers.Conv3DTranspose(width, **up_kwargs)(x)
        x = layers.concatenate([x, skips.pop()])
        x = double_conv(x, width)

    # Per-voxel class probabilities
    outputs = layers.Conv3D(numClasses, (1, 1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

### U-ResNet

In [None]:
def ResNet(inputShape, numClasses):
    """Assemble a 3D U-ResNet and return it as an uncompiled Keras ``Model``.

    Same encoder / bottleneck / decoder layout as a U-Net with four levels,
    but each block adds the output of its first convolution to the output
    of its second convolution (a residual connection).

    Args:
        inputShape: Shape handed to ``layers.Input`` (batch axis excluded).
            Presumably (depth, height, width, channels) - TODO confirm.
        numClasses: Number of filters in the final 1x1x1 softmax convolution.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output tensor.

    NOTE(review): there are four 2x max-pooling steps, so each spatial
    dimension presumably has to be divisible by 16 for the skip
    concatenations to line up - confirm against the data pipeline.
    """
    # Keyword arguments shared by every regular / transposed convolution
    conv_kwargs = dict(kernel_size=(3, 3, 3), activation='relu', padding='same')
    up_kwargs = dict(kernel_size=(2, 2, 2), strides=(2, 2, 2), padding='same')

    # Feature-map counts: four encoder/decoder levels, then the bottleneck
    widths = [32, 64, 128, 256, 512]
    level_widths = widths[:-1]
    bottleneck_width = widths[-1]

    def residual_block(tensor, width):
        # Two chained convolutions whose outputs are summed element-wise
        first = layers.Conv3D(width, **conv_kwargs)(tensor)
        second = layers.Conv3D(width, **conv_kwargs)(first)
        return layers.Add()([first, second])

    inputs = layers.Input(inputShape)
    x = inputs

    # Encoder: residual block, remember the result for the skip link, pool
    skips = []
    for width in level_widths:
        x = residual_block(x, width)
        skips.append(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2))(x)

    # Bottleneck
    x = residual_block(x, bottleneck_width)

    # Decoder: upsample, merge with the matching encoder output, then apply
    # a residual block. The deepest skip tensor is consumed first.
    for width in reversed(level_widths):
        x = layers.Conv3DTranspose(width, **up_kwargs)(x)
        x = layers.concatenate([x, skips.pop()])
        x = residual_block(x, width)

    # Per-voxel class probabilities
    outputs = layers.Conv3D(numClasses, (1, 1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

### U-DenseNet

In [None]:
def DenseNet(inputShape, numClasses):
    """Assemble a 3D U-DenseNet and return it as an uncompiled Keras ``Model``.

    Same encoder / bottleneck / decoder layout as a U-Net with four levels.
    Inside each block the input and all earlier convolution outputs are
    concatenated before the next step, and a final 1x1x1 convolution brings
    the feature-map count back down to the block's filter width.

    Args:
        inputShape: Shape handed to ``layers.Input`` (batch axis excluded).
            Presumably (depth, height, width, channels) - TODO confirm.
        numClasses: Number of filters in the final 1x1x1 softmax convolution.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output tensor.

    NOTE(review): there are four 2x max-pooling steps, so each spatial
    dimension presumably has to be divisible by 16 for the skip
    concatenations to line up - confirm against the data pipeline.
    """
    # Keyword arguments for the 3x3x3, pointwise, and transposed convolutions
    conv_kwargs = dict(kernel_size=(3, 3, 3), activation='relu', padding='same')
    pointwise_kwargs = dict(kernel_size=(1, 1, 1), activation='relu', padding='same')
    up_kwargs = dict(kernel_size=(2, 2, 2), strides=(2, 2, 2), padding='same')

    # Feature-map counts: four encoder/decoder levels, then the bottleneck
    widths = [32, 64, 128, 256, 512]
    level_widths = widths[:-1]
    bottleneck_width = widths[-1]

    def dense_block(tensor, width):
        # First convolution sees only the block input
        first = layers.Conv3D(width, **conv_kwargs)(tensor)

        # Second convolution sees the block input plus the first output
        second_in = layers.concatenate([tensor, first])
        second = layers.Conv3D(width, **conv_kwargs)(second_in)

        # Gather everything produced so far, then shrink it with a
        # pointwise convolution
        gathered = layers.concatenate([tensor, first, second])
        return layers.Conv3D(width, **pointwise_kwargs)(gathered)

    inputs = layers.Input(inputShape)
    x = inputs

    # Encoder: dense block, remember the result for the skip link, pool
    skips = []
    for width in level_widths:
        x = dense_block(x, width)
        skips.append(x)
        x = layers.MaxPooling3D(pool_size=(2, 2, 2))(x)

    # Bottleneck
    x = dense_block(x, bottleneck_width)

    # Decoder: upsample, merge with the matching encoder output, then apply
    # a dense block. The deepest skip tensor is consumed first.
    for width in reversed(level_widths):
        x = layers.Conv3DTranspose(width, **up_kwargs)(x)
        x = layers.concatenate([x, skips.pop()])
        x = dense_block(x, width)

    # Per-voxel class probabilities
    outputs = layers.Conv3D(numClasses, (1, 1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])