In [1]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

In [2]:
def initial_block(inp):
    """Build the ENet initial block.

    Two parallel branches are applied to ``inp`` and concatenated along the
    last (channel) axis: a strided 3x3 convolution with 13 filters and a
    2x2 max-pool of the raw input. Both branches halve the spatial size.

    Args:
        inp: Input Keras tensor (channels-last).

    Returns:
        Tensor with 13 + (input channels) channels at half the spatial size.
    """
    conv = Conv2D(filters=13, kernel_size=3, strides=2, padding='same', use_bias=False, kernel_initializer='he_normal')(inp)
    # Pool with 'same' padding so its output size is ceil(n / 2), exactly like
    # the strided 'same' convolution above. With the default 'valid' padding
    # the two branches disagree on odd input sizes and the concat fails.
    # For even input sizes the result is unchanged.
    pool = MaxPool2D(2, padding='same')(inp)
    return concatenate([conv, pool])

In [3]:
def encoder_bottleneck(inp, filters, dilation_rate=2, downsample=False, dilated=False, asymmetric=False, drop_rate=0.1):
    """Build one ENet encoder bottleneck (residual) block.

    The main branch is reduce -> conv -> expand, each step followed by batch
    normalisation, with spatial dropout at the end. Its output is added to a
    shortcut branch and the sum goes through a PReLU.

    Args:
        inp: Input Keras tensor (channels-last). When ``downsample`` is false
            it must already have ``filters`` channels, because it is added
            to the main branch unchanged.
        filters: Number of output channels of the block.
        dilation_rate: Dilation of the 3x3 convolution; only used when
            ``dilated`` is true.
        downsample: If true, halve the spatial size. The reduce step becomes
            a 2x2 stride-2 convolution and the shortcut is max-pooled and
            zero-padded along the channel axis up to ``filters`` channels.
        dilated: Use a dilated 3x3 convolution as the middle step.
        asymmetric: Use a 1x5 followed by a 5x1 convolution as the middle
            step. Ignored when ``dilated`` is also true.
        drop_rate: Rate of the SpatialDropout2D applied to the main branch.

    Returns:
        Output Keras tensor with ``filters`` channels.
    """
    # NOTE(review): Keras' BatchNormalization ``momentum`` weights the *old*
    # running statistics (Keras default 0.99). 0.1 looks like a value carried
    # over from a framework with the opposite convention -- confirm intended.
    reduce = filters // 4
    down = inp
    kernel_stride = 1

    # Downsample: shortcut branch is pooled and channel-padded.
    if downsample:
        kernel_stride = 2
        # int(shape[-1]) works for TensorShape and plain tuple shapes alike.
        pad_activations = filters - int(inp.shape[-1])
        # 'same' keeps the pooled size equal to the strided 'same' conv of the
        # main branch for odd inputs too (identical to 'valid' on even sizes).
        down = MaxPool2D(2, padding='same')(down)
        # ZeroPadding2D only pads axes 1 and 2, so swap width and channels,
        # pad what is now axis 2 (the channels), and swap back.
        down = Permute(dims=(1, 3, 2))(down)
        down = ZeroPadding2D(padding=((0, 0), (0, pad_activations)))(down)
        down = Permute(dims=(1, 3, 2))(down)

    # Reduce: 1x1 projection, or 2x2 stride-2 projection when downsampling.
    x = Conv2D(filters=reduce, kernel_size=kernel_stride, strides=kernel_stride, padding='same', use_bias=False, kernel_initializer='he_normal')(inp)
    x = BatchNormalization(momentum=0.1)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # Conv: the middle step always uses a 3x3 kernel (or the 1x5/5x1 pair).
    # It was previously sized by ``kernel_stride`` (1, or 2 when downsampling),
    # which turned every dilated block into a plain 1x1 convolution because
    # dilation has no effect on a 1x1 kernel.
    if not dilated and not asymmetric:
        x = Conv2D(filters=reduce, kernel_size=3, padding='same', kernel_initializer='he_normal')(x)
    elif dilated:
        x = Conv2D(filters=reduce, kernel_size=3, padding='same', dilation_rate=dilation_rate, kernel_initializer='he_normal')(x)
    elif asymmetric:
        x = Conv2D(filters=reduce, kernel_size=(1, 5), padding='same', use_bias=False, kernel_initializer='he_normal')(x)
        x = Conv2D(filters=reduce, kernel_size=(5, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization(momentum=0.1)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # Expand: always a 1x1 projection back to ``filters`` channels
    # (previously 2x2 in downsampling blocks because it reused kernel_stride).
    x = Conv2D(filters=filters, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(momentum=0.1)(x)
    x = SpatialDropout2D(rate=drop_rate)(x)

    # Residual merge of main branch and shortcut.
    concat = Add()([x, down])
    concat = PReLU(shared_axes=[1, 2])(concat)
    return concat

In [4]:
def decoder_bottleneck(inp, filters, upsample=False):
    """Build one ENet decoder bottleneck (residual) block.

    The main branch is a 1x1 reduce, a middle convolution and a 1x1 expand,
    each followed by batch normalisation. It is summed with a shortcut branch
    and the sum goes through a PReLU.

    Args:
        inp: Input Keras tensor (channels-last). When ``upsample`` is false
            it is used as the shortcut unchanged, so it must already have
            ``filters`` channels.
        filters: Number of output channels of the block.
        upsample: If true, double the spatial size. The shortcut becomes a
            1x1 convolution followed by 2x UpSampling2D, and the middle
            convolution becomes a stride-2 3x3 transposed convolution.

    Returns:
        Output Keras tensor with ``filters`` channels.
    """
    bottleneck_width = filters // 4
    # Arguments shared by every 1x1 projection in this block.
    pointwise = dict(kernel_size=1, strides=1, padding='same', use_bias=False, kernel_initializer='he_normal')

    # Shortcut branch: identity, or project + upsample.
    shortcut = inp
    if upsample:
        shortcut = UpSampling2D(size=2)(Conv2D(filters=filters, **pointwise)(shortcut))

    # Main branch, step 1: 1x1 reduce.
    main = Conv2D(filters=bottleneck_width, **pointwise)(inp)
    main = BatchNormalization(momentum=0.1)(main)
    main = PReLU(shared_axes=[1, 2])(main)

    # Main branch, step 2: 3x3 conv, transposed with stride 2 when upsampling.
    middle = (
        Conv2DTranspose(filters=bottleneck_width, kernel_size=3, strides=2, padding='same', kernel_initializer='he_normal')
        if upsample
        else Conv2D(filters=bottleneck_width, kernel_size=3, strides=1, padding='same', kernel_initializer='he_normal')
    )
    main = middle(main)
    main = BatchNormalization(momentum=0.1)(main)
    main = PReLU(shared_axes=[1, 2])(main)

    # Main branch, step 3: 1x1 expand back to ``filters`` channels.
    main = Conv2D(filters=filters, **pointwise)(main)
    main = BatchNormalization(momentum=0.1)(main)

    # Residual merge followed by the block's output activation.
    merged = Add()([main, shortcut])
    return PReLU(shared_axes=[1, 2])(merged)

In [5]:
def ENet(inp, nclasses=1):
    """Assemble the ENet encoder/decoder and return it as a Keras ``Model``.

    The encoder reduces the spatial size three times (initial block, first
    stage-1 bottleneck, first stage-2 bottleneck). The decoder doubles it
    three times (first bottleneck of stages 4 and 5, and the final transposed
    convolution).

    Args:
        inp: Keras ``Input`` tensor the model is built on.
        nclasses: Number of channels of the output map. With one class the
            output uses a sigmoid; otherwise a softmax over the channel axis.

    Returns:
        The constructed ``Model`` named ``'Enet'``.

    Side effects:
        Writes the freshly built model to ``'enet.h5'`` in the current
        working directory on every call (kept for backward compatibility).
    """
    enc = initial_block(inp)
    enc = BatchNormalization(momentum=0.1)(enc)
    enc = PReLU(shared_axes=[1, 2])(enc)

    # Stage 1: one downsampling bottleneck followed by four regular ones.
    # NOTE(review): the ENet paper reports a dropout of 0.01 for this stage;
    # confirm that 0.001 is intended.
    enc = encoder_bottleneck(enc, 64, downsample=True, drop_rate=0.001)
    for _ in range(4):
        enc = encoder_bottleneck(enc, 64, drop_rate=0.001)

    # Stage 2 starts with a downsampling bottleneck ...
    enc = encoder_bottleneck(enc, 128, downsample=True)
    # ... then stages 2 and 3 share the same eight-block body.
    for _ in range(2):
        for rate in (2, 4, 8, 16):
            # regular / dilated pairs alternate with asymmetric / dilated pairs
            if rate in (2, 8):
                enc = encoder_bottleneck(enc, 128)
            else:
                enc = encoder_bottleneck(enc, 128, asymmetric=True)
            enc = encoder_bottleneck(enc, 128, dilation_rate=rate, dilated=True)
            if rate == 2:
                # block order is: regular, dilated 2, asymmetric, dilated 4, ...
                pass

    # Stage 4: upsample, then two regular decoder bottlenecks.
    dec = decoder_bottleneck(enc, 64, upsample=True)
    dec = decoder_bottleneck(dec, 64)
    dec = decoder_bottleneck(dec, 64)

    # Stage 5: upsample, then one regular decoder bottleneck.
    dec = decoder_bottleneck(dec, 16, upsample=True)
    dec = decoder_bottleneck(dec, 16)

    # Softmax over a single channel is identically 1.0, so a one-class
    # (binary mask) head must use a sigmoid instead.
    final_activation = 'sigmoid' if nclasses == 1 else 'softmax'
    dec = Conv2DTranspose(filters=nclasses, kernel_size=3, strides=2, padding='same', activation=final_activation)(dec)

    model = Model(inputs=inp, outputs=dec, name='Enet')
    model.save('enet.h5')
    return model

In [6]:
# Build the network on 512x512 inputs with 3 channels and 3 output classes.
inp = Input(shape=(512, 512, 3))
enet = ENet(inp, nclasses=3)