In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, SeparableConv2D, \
     Add, Dense, BatchNormalization, ReLU, MaxPool2D, GlobalAvgPool2D, Cropping2D

In [2]:
# (filters, dropout rate) for every encoder stage; the decoder mirrors this
# list in reverse order.
_ENCODER_STAGES = ((64, 0.1), (128, 0.1), (256, 0.2), (512, 0.2))
# Every encoder stage ends in a 2x2 max-pool, so the bottleneck is input / 16.
_DOWNSCALE = 2 ** len(_ENCODER_STAGES)
# The merge block removes a quarter of the bottleneck on each side, so the
# bottleneck size must be a multiple of 4 -> input size a multiple of 64.
_SIZE_MULTIPLE = 4 * _DOWNSCALE


def _double_conv(x, filters, dropout_rate):
    """Apply Conv3x3 -> Dropout -> Conv3x3 (ReLU, he_normal, 'same' padding)."""
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')
    x = tf.keras.layers.Conv2D(filters, (3, 3), **conv_kwargs)(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    return tf.keras.layers.Conv2D(filters, (3, 3), **conv_kwargs)(x)


def _encoder(x):
    """Run the four-stage contraction path; return (pooled bottleneck, skip tensors)."""
    skips = []
    for filters, dropout_rate in _ENCODER_STAGES:
        x = _double_conv(x, filters, dropout_rate)
        skips.append(x)  # pre-pooling features, reused by the decoder
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    return x, skips


def msY_Net(input_shape=(256,256,3)):
    """Build the two-input msY-Net segmentation model.

    The network has two structurally identical (but separately weighted)
    U-Net style encoders: a main encoder for the detail patch and a side
    encoder for the context patch. The context bottleneck is center-cropped
    to its central half, bilinearly upsampled by 2 so it matches the detail
    bottleneck again, concatenated with it and fused by a 1x1 convolution.
    A U-Net decoder with skip connections from the detail encoder produces
    the output.

    Args:
        input_shape: ``(height, width, channels)`` shared by both inputs.
            Height and width must be positive multiples of 64 so that the
            four poolings and the center crop divide evenly.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs
        ``[detail_patch, context_patch]`` and a single one-channel sigmoid
        output of the same spatial size as the inputs.

    Raises:
        ValueError: if height or width is not a positive multiple of 64.
    """
    # Keras channels_last order is (rows, cols, channels) = (height, width, channels).
    img_height, img_width, img_channels = input_shape

    # Validate before any layer is created so a bad shape fails fast with a
    # clear message instead of a shape mismatch deep inside the graph.
    for axis_name, size in (('height', img_height), ('width', img_width)):
        if size is None or size <= 0 or size % _SIZE_MULTIPLE != 0:
            raise ValueError(
                "msY_Net requires the input %s to be a positive multiple of %d, got %r"
                % (axis_name, _SIZE_MULTIPLE, size)
            )

    # Margin removed on each side of the context bottleneck: a quarter of its
    # size, leaving the central half (4 of 16 for the default 256x256 input).
    crop_h = (img_height // _DOWNSCALE) // 4
    crop_w = (img_width // _DOWNSCALE) // 4

    #BUILD THE MODEL
    #input layers / d=detail patches, c=context patches
    inputs_d = tf.keras.layers.Input((img_height, img_width, img_channels))
    inputs_c = tf.keras.layers.Input((img_height, img_width, img_channels))
    #scale by 1/255 (assumes pixel values in 0..255 -- TODO confirm against the data loader)
    s_d = tf.keras.layers.Lambda(lambda x: x/255)(inputs_d)
    s_c = tf.keras.layers.Lambda(lambda x: x/255)(inputs_c)

    #CONTRACTION PATH (downscaling) -> Main Encoder - Detail Patch
    p4, skips_d = _encoder(s_d)

    #CONTRACTION PATH (downscaling) -> Side Encoder - context patch
    # Only the bottleneck of the context branch is used; its skips are dropped.
    p4_c, _ = _encoder(s_c)

    #msM - Multi-Scale Merge Block
    msM = tf.keras.layers.Cropping2D(
        cropping=((crop_h, crop_h), (crop_w, crop_w)),  # ((top,bottom),(left,right))
        data_format='channels_last',
    )(p4_c)
    msM = tf.keras.layers.UpSampling2D(
        size=(2, 2), data_format='channels_last', interpolation='bilinear'
    )(msM)
    msM = tf.keras.layers.concatenate([p4, msM])
    x = tf.keras.layers.Conv2D(
        512, (1, 1), activation='relu', kernel_initializer='he_normal', padding='same'
    )(msM)

    #EXPANSIVE PATH (upscaling) - Decoder
    # Walk the encoder stages backwards, pairing each with its detail skip.
    for (filters, dropout_rate), skip in zip(reversed(_ENCODER_STAGES), reversed(skips_d)):
        x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.concatenate([x, skip])
        x = _double_conv(x, filters, dropout_rate)

    outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = tf.keras.Model(inputs=[inputs_d, inputs_c], outputs=[outputs])

    return model

In [3]:
# Instantiate msY-Net for 256x256 RGB patches and set it up for binary
# segmentation (Adam optimizer, binary cross-entropy loss, accuracy metric).
model = msY_Net(input_shape=(256, 256, 3))
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Print the layer table as a quick architecture check.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 256, 256, 3)  0           input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 256, 256, 64) 1792        lambda_1[0][0]                   
______________________________________________________________________________________________

In [None]:
#export summary for sanity check
from contextlib import redirect_stdout

# model.summary() prints to stdout, so redirect stdout into the text file.
with open('msY_Net_modelsummary.txt', 'w') as f:
    with redirect_stdout(f):
        model.summary()

#plot model
# `keras` is never imported as a bare name in this notebook; go through `tf`.
# NOTE: plot_model needs the optional `pydot` package and Graphviz installed.
tf.keras.utils.plot_model(model, "msY_Net_model.png")