In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Conv2D, MaxPooling2D, Conv2DTranspose, Concatenate, experimental

In [2]:
def _double_conv(x, filters, index):
    """Apply two unpadded 3x3 ReLU convolutions.

    Args:
        x: input Keras tensor.
        filters: number of filters in both convolutions.
        index: number used in the first layer's name; the two layers are
            named ``conv{index}`` and ``conv{index + 1}``.

    Returns:
        The output tensor, 4 pixels smaller than ``x`` in height and width
        (``padding='valid'`` is the Conv2D default, so each 3x3 conv trims 2).
    """
    x = Conv2D(filters, 3, activation='relu', name=f'conv{index}')(x)
    return Conv2D(filters, 3, activation='relu', name=f'conv{index + 1}')(x)


def _up_block(x, skip, filters, level, conv_index):
    """Decoder step: 2x upsample, crop-and-concatenate the skip tensor, double conv.

    Args:
        x: decoder tensor to upsample.
        skip: encoder tensor at the matching resolution.
        filters: filters for the transposed conv and both 3x3 convs.
        level: 1-based decoder level, used in the layer names
            ``upsample{level}``, ``crop{level}`` and ``concat{level}``.
        conv_index: number used in the name of the block's first 3x3 conv.

    Returns:
        The output tensor of the block's second 3x3 convolution.
    """
    up = Conv2DTranspose(filters, 2, strides=(2, 2), name=f'upsample{level}')(x)
    # Valid convolutions shrink the decoder path more than the encoder path, so
    # the skip tensor is spatially larger. Crop it to the upsampled tensor's
    # static size rather than hard-coding sizes valid for one input shape only.
    cropped = experimental.preprocessing.CenterCrop(
        up.shape[1], up.shape[2], name=f'crop{level}')(skip)
    merged = Concatenate(axis=-1, name=f'concat{level}')([up, cropped])
    return _double_conv(merged, filters, conv_index)


def unet_model(input_shape=(572, 572, 1), n_classes=2):
    """Build the original unpadded U-Net (Ronneberger et al., 2015).

    Every 3x3 convolution is unpadded, so the output map is smaller than the
    input (388x388 for the default 572x572 input). Each encoder feature map is
    center-cropped to the size of the matching upsampled decoder tensor before
    the two are concatenated.

    Args:
        input_shape: (height, width, channels) of the input image. Defaults to
            a 572x572 single-channel tile. Height and width must each be of
            the form ``16*k + 60`` with ``k >= 8`` (e.g. 572 = 16*32 + 60). This
            ensures that every 2x2 max-pool sees an even size and that every
            decoder feature map keeps a positive size.
        n_classes: number of output channels of the final 1x1 convolution.
            Defaults to 2. No activation is applied, so the outputs are logits.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: if height or width is None or not of the supported form.
    """
    for dim in input_shape[:2]:
        if dim is None or dim < 188 or (dim - 60) % 16 != 0:
            raise ValueError(
                f'unet_model: spatial size {dim} is unsupported; height and '
                'width must be 16*k + 60 with k >= 8 (e.g. 572).')

    # Named `inputs` to avoid shadowing the builtin `input`; the layer name is
    # unchanged.
    inputs = Input(input_shape, name='input')

    # Contracting path: double conv, then 2x2 max-pool (no pool at the bottom).
    conv2 = _double_conv(inputs, 64, 1)
    maxpool1 = MaxPooling2D((2, 2), name='maxpool1')(conv2)
    conv4 = _double_conv(maxpool1, 128, 3)
    maxpool2 = MaxPooling2D((2, 2), name='maxpool2')(conv4)
    conv6 = _double_conv(maxpool2, 256, 5)
    maxpool3 = MaxPooling2D((2, 2), name='maxpool3')(conv6)
    conv8 = _double_conv(maxpool3, 512, 7)
    maxpool4 = MaxPooling2D((2, 2), name='maxpool4')(conv8)
    conv10 = _double_conv(maxpool4, 1024, 9)

    # Expanding path: each step consumes the encoder output at the same level.
    conv12 = _up_block(conv10, conv8, 512, 1, 11)
    conv14 = _up_block(conv12, conv6, 256, 2, 13)
    conv16 = _up_block(conv14, conv4, 128, 3, 15)
    conv18 = _up_block(conv16, conv2, 64, 4, 17)

    # 1x1 conv maps the 64 features to per-pixel class scores.
    conv19 = Conv2D(n_classes, 1, name='conv19')(conv18)
    return tf.keras.Model(inputs, conv19)

# Build the network with its default configuration and print the layer table
# (output shapes, parameter counts, connections) captured below.
model = unet_model()
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 572, 572, 1) 0                                            
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 570, 570, 64) 640         input[0][0]                      
__________________________________________________________________________________________________
conv2 (Conv2D)                  (None, 568, 568, 64) 36928       conv1[0][0]                      
__________________________________________________________________________________________________
maxpool1 (MaxPooling2D)         (None, 284, 284, 64) 0           conv2[0][0]                      
_______________________________________________________________________________________