In [1]:
from keras.models import Model
from keras.layers import BatchNormalization, Input, Conv2D, MaxPooling2D, UpSampling2D, Activation, Concatenate

In [2]:
def conv(x, filters, layername=None):
    """Apply a 3x3 convolution, then batch normalization, then ReLU.

    :param x: Input tensor the block is applied to
    :param filters: Number of filters of the convolution
    :param layername: Optional name assigned to the Conv2D layer
    :return: Output tensor of the convolutional block
    """
    # 'same' padding keeps the spatial size of the input unchanged.
    conv_layer = Conv2D(
        filters=filters,
        kernel_size=(3, 3),
        padding='same',
        kernel_initializer='he_uniform',
        name=layername,
    )
    convolved = conv_layer(x)
    normalized = BatchNormalization()(convolved)
    return Activation('relu')(normalized)

In [3]:
def Unet(filters=16, layers=4, input_shape=(None, None, 1), num_classes=1, activation='sigmoid'):
    """U-Net for semantic segmentation
    
    Reference:
    Ronneberger, O., Fischer, P., & Brox, T. (2015, October).
    U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted
    intervention (pp. 234-241). Springer, Cham
    
    :param filters: Number of filters in the first layer; doubled at every deeper level.
    :param layers: Number of encoder and decoder layers (non-negative integer).
        Each encoder layer halves the spatial size, so concrete input sizes
        should be divisible by 2**layers for the skip connections to line up.
    :param input_shape: Input shape (height, width, channels); spatial sizes may
        be None for variable-sized inputs.
    :param num_classes: Number of output channels (classes) of the final 1x1 convolution
    :param activation: Activation of last layer
    
    :return: U-Net as Keras model
    :raises ValueError: If ``layers`` is negative.
    """
    if layers < 0:
        raise ValueError(f'layers must be non-negative, got {layers!r}')
    
    # Encoder outputs kept for the skip connections (used in reverse order).
    to_concat = []
    
    # Input layer
    inputs = Input(input_shape, name='input_layer')
    x = inputs
    
    # Encoder: two conv blocks per level, then downsample by 2.
    for i in range(layers):
        x = conv(x, filters * 2**i, f'enc_layer{i}_conv1')
        x = conv(x, filters * 2**i, f'enc_layer{i}_conv2')
        to_concat.append(x)
        x = MaxPooling2D()(x)
        
    # Bottleneck. The width is derived from `layers` rather than from the
    # encoder's loop variable, which is undefined when layers == 0.
    x = conv(x, filters * 2**layers, 'latent_conv')
    
    # Decoder: upsample by 2, merge with the matching encoder output,
    # then two conv blocks with the same width as that encoder level.
    for i in range(layers):
        level_filters = filters * 2**(layers - 1 - i)
        x = UpSampling2D()(x)
        x = Concatenate()([x, to_concat.pop()])
        x = conv(x, level_filters, f'dec_layer{i}_conv1')
        x = conv(x, level_filters, f'dec_layer{i}_conv2')
    
    # Output layer with 1x1 convolution
    outputs = Conv2D(num_classes,
                     kernel_size=(1, 1),
                     padding='same',
                     kernel_initializer='he_uniform',
                     activation=activation,
                     name='final_activation')(x)
    
    return Model(inputs, outputs)

**Unet_v2** below is less fancy, but it may be easier to understand (e.g. it is explicit about which layers' outputs are concatenated). <br> It hard-codes four encoder/decoder levels, so its architecture is exactly the same as **Unet** with the default `layers=4`.

In [4]:
def Unet_v2(filters=16, layers=4, input_shape=(None, None, 1), num_classes=1, activation='sigmoid'):
    """U-Net for semantic segmentation (explicit, unrolled four-level variant)
    
    Reference:
    Ronneberger, O., Fischer, P., & Brox, T. (2015, October).
    U-net: Convolutional networks for biomedical image segmentation.
    In International Conference on Medical image computing and computer-assisted
    intervention (pp. 234-241). Springer, Cham
    
    :param filters: Number of filters in the first layer; doubled at every deeper level.
    :param layers: Number of encoder and decoder layers. This variant is written
        out by hand for exactly 4 layers, so only ``layers=4`` is accepted.
    :param input_shape: Input shape (height, width, channels); spatial sizes may
        be None for variable-sized inputs.
    :param num_classes: Number of output channels (classes) of the final 1x1 convolution
    :param activation: Activation of last layer
    
    :return: U-Net as Keras model
    :raises ValueError: If ``layers`` is not 4.
    """
    # The network below is unrolled for four levels; reject any other depth
    # instead of silently returning a model that ignores the argument.
    if layers != 4:
        raise ValueError(
            f'Unet_v2 is hard-coded for layers=4, got layers={layers!r}'
        )
    
    # Input layer
    inputs = Input(input_shape)
    x = inputs
    
    # Encoder: two conv blocks per level, then 2x2 max pooling.
    # c1..c4 are kept for the skip connections in the decoder.
    c1 = conv(x, filters)
    c1 = conv(c1, filters)
    p1 = MaxPooling2D((2, 2))(c1)
    
    c2 = conv(p1, filters * 2)
    c2 = conv(c2, filters * 2)
    p2 = MaxPooling2D((2, 2))(c2)
    
    c3 = conv(p2, filters * 4)
    c3 = conv(c3, filters * 4)
    p3 = MaxPooling2D((2, 2))(c3)
    
    c4 = conv(p3, filters * 8)
    c4 = conv(c4, filters * 8)
    p4 = MaxPooling2D((2, 2))(c4)
    
    # Bottleneck
    c5 = conv(p4, filters * 16)
    
    # Decoder: upsample, concatenate with the encoder output of the same
    # resolution, then two conv blocks.
    d1 = UpSampling2D()(c5)
    d1 = Concatenate()([d1, c4])
    d1 = conv(d1, filters * 8)
    d1 = conv(d1, filters * 8)
    
    d2 = UpSampling2D()(d1)
    d2 = Concatenate()([d2, c3])
    d2 = conv(d2, filters * 4)
    d2 = conv(d2, filters * 4)
    
    d3 = UpSampling2D()(d2)
    d3 = Concatenate()([d3, c2])
    d3 = conv(d3, filters * 2)
    d3 = conv(d3, filters * 2)
    
    d4 = UpSampling2D()(d3)
    d4 = Concatenate()([d4, c1])
    d4 = conv(d4, filters)
    d4 = conv(d4, filters)
    
    # Output layer with 1x1 convolution
    outputs = Conv2D(num_classes,
                     (1, 1),
                     activation=activation,
                     padding='same',
                     kernel_initializer='he_uniform')(d4)
    
    return Model(inputs, outputs)

In [5]:
# Example: build a 4-layer U-Net for 32 x 32 grayscale images
# (a single color channel) and report how many parameters it has.

unet = Unet(layers=4, filters=16, input_shape=(32, 32, 1))

print('Number of parameters:', unet.count_params())

Number of parameters: 1377121


In [6]:
# Print a per-layer overview of the model: layer name and type, output shape,
# parameter count and the layers each one is connected to.
unet.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_layer (InputLayer)       [(None, 32, 32, 1)]  0           []                               
                                                                                                  
 enc_layer0_conv1 (Conv2D)      (None, 32, 32, 16)   160         ['input_layer[0][0]']            
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 16)  64          ['enc_layer0_conv1[0][0]']       
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 32, 32, 16)   0           ['batch_normalization[0][0]']

 batch_normalization_9 (BatchNo  (None, 4, 4, 128)   512         ['dec_layer0_conv1[0][0]']       
 rmalization)                                                                                     
                                                                                                  
 activation_9 (Activation)      (None, 4, 4, 128)    0           ['batch_normalization_9[0][0]']  
                                                                                                  
 dec_layer0_conv2 (Conv2D)      (None, 4, 4, 128)    147584      ['activation_9[0][0]']           
                                                                                                  
 batch_normalization_10 (BatchN  (None, 4, 4, 128)   512         ['dec_layer0_conv2[0][0]']       
 ormalization)                                                                                    
                                                                                                  
 activatio