In [1]:
%run headers.ipynb

## Unet Construction:

#### Example of UNet structure:

In [2]:
# %%html
# <img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png" style="width:100%">
# <img src="https://i.ibb.co/zQ75mGw/Screenshot-219.png" style="width:100%">

In [3]:
# %%html
# <img src="https://i.ibb.co/ZWSj0HQ/Screenshot-222.png" style="width:100%">
# <img src="https://i.ibb.co/XxLLSfH/Screenshot-223.png" style="width:100%">

#### We construct encoder and decoder blocks to simplify the problem. This decreases its complexity.

In [4]:
def conv_block(input_, num_filters):
    """Run the input through two Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        input_: Tensor fed to the first convolution.
        num_filters: Filter count shared by both 3x3 convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = input_
    # Both stages are configured identically: 3x3 kernel, he_normal init, "same" padding.
    for _ in range(2):
        features = Conv2D(
            filters=num_filters,
            kernel_size=3,
            kernel_initializer='he_normal',
            padding="same",
        )(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

In [5]:
#green arrow
def gating_signal(input_,num_filters):
    """Produce a gating tensor via Conv2D -> BatchNormalization -> ReLU.

    Args:
        input_: Tensor the gating signal is derived from.
        num_filters: Filter count of the 3x3, stride-1, "same"-padded convolution.

    Returns:
        The ReLU-activated, batch-normalized convolution output.
    """
    projection = Conv2D(filters=num_filters, kernel_size=3, strides=(1, 1), padding='same')
    gate = projection(input_)
    gate = BatchNormalization()(gate)
    return Activation('relu')(gate)

In [6]:
def attention_gate(input_,gating_input,num_filters):
    """Re-weight skip-connection features with a per-pixel attention map.

    The skip tensor is reduced by a stride-2 3x3 convolution, summed with the
    gating tensor and turned into a single-channel weight map
    (ReLU -> 1x1 Conv2D -> sigmoid). The map is upsampled back to the skip
    tensor's spatial size, repeated along the channel axis, multiplied with
    the skip tensor, and finally passed through a 1x1 Conv2D and
    BatchNormalization.

    Args:
        input_: Skip-connection tensor. Its shape is indexed as
            [1], [2] (spatial) and [3] (channels), so a 4D channels-last
            tensor with statically known sizes is assumed.
        gating_input: Tensor added element-wise to the strided convolution of
            ``input_``. It is expected to have ``num_filters`` channels and
            half the spatial size of ``input_`` so the sum is valid
            -- TODO confirm this holds for odd spatial sizes.
        num_filters: Filter count of the strided convolution applied to
            ``input_``.

    Returns:
        Batch-normalized tensor produced by a 1x1 convolution with as many
        filters as ``input_`` has channels.
    """
    # `keras` is used here for int_shape/repeat_elements, i.e. presumably the
    # backend module alias provided by headers.ipynb -- verify.
    skip_shape = keras.int_shape(input_)

    # Stride 2 halves the skip features so they can be summed with the gate.
    skip_down = Conv2D(num_filters, (3, 3), strides=(2, 2), padding='same')(input_)

    # Element-wise sum (not a concatenation) followed by ReLU.
    combined = add([skip_down, gating_input])
    combined = Activation("relu")(combined)

    # Collapse to one channel and squash with a sigmoid: one weight per pixel.
    psi = Conv2D(1, (1, 1), padding='same')(combined)
    psi = Activation("sigmoid")(psi)
    psi_shape = keras.int_shape(psi)

    # Integer upsampling factors that undo the stride-2 reduction.
    scale = (skip_shape[1] // psi_shape[1], skip_shape[2] // psi_shape[2])
    psi_up = UpSampling2D(size=scale)(psi)

    # Repeat the single-channel map so it has one copy per skip channel.
    psi_up = Lambda(
        lambda x, repnum: keras.repeat_elements(x, repnum, axis=3),
        arguments={'repnum': skip_shape[3]},
    )(psi_up)

    gated = multiply([psi_up, input_])

    result = Conv2D(skip_shape[3], (1, 1), padding='same')(gated)
    return BatchNormalization()(result)

In [7]:
# Module-level tally of encoder_block() calls; encoder_block bumps it through `global count`.
count = 0

In [8]:
def encoder_block(input_, num_filters, dropout_rate=0.075):
    """Build one encoder stage: conv_block -> 2x2 max pooling -> dropout.

    Args:
        input_: Tensor entering this stage.
        num_filters: Filter count forwarded to ``conv_block``.
        dropout_rate: Rate passed to the Dropout layer applied after pooling.
            Defaults to 0.075, the value previously hard-coded here.

    Returns:
        A ``(conv, pooled)`` tuple: ``conv`` is the ``conv_block`` output
        before pooling (the builders pass it on as the skip connection) and
        ``pooled`` is the max-pooled, dropout-regularized tensor fed to the
        next stage.
    """
    # Keeps the module-level call tally up to date. Nothing in this function
    # reads it back; the increment is retained so module state is unchanged.
    global count
    count += 1

    conv = conv_block(input_, num_filters)
    pooled = MaxPool2D((2, 2))(conv)
    pooled = Dropout(dropout_rate)(pooled)
    return conv, pooled

In [9]:
def decoder_block(input_, skip_features, num_filters, dropout_rate=0.075):
    """Build one decoder stage: transposed conv -> concat skip -> dropout -> conv_block.

    Args:
        input_: Tensor coming from the previous (coarser) stage.
        skip_features: Tensor concatenated with the upsampled ``input_``.
        num_filters: Filter count for both the 2x2, stride-2 transposed
            convolution and the trailing ``conv_block``.
        dropout_rate: Rate passed to the Dropout layer applied to the
            concatenated tensor. Defaults to 0.075, the value previously
            hard-coded here.

    Returns:
        The tensor returned by ``conv_block`` for this stage.
    """
    upsampled = Conv2DTranspose(
        filters=num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input_)
    merged = Concatenate()([upsampled, skip_features])
    merged = Dropout(dropout_rate)(merged)
    return conv_block(merged, num_filters)

In [10]:
def unet_build(input_shape):
    """Assemble the plain U-Net model.

    Four encoder stages (16, 32, 64, 128 filters) feed a 256-filter bridge,
    followed by four decoder stages (128, 64, 32, 16 filters) that each
    consume the matching encoder output as skip features. A 1x1 sigmoid
    convolution with a single filter produces the output.

    Args:
        input_shape: Shape handed to ``Input``.

    Returns:
        A ``Model`` named "U-Net" mapping the input tensor to the output tensor.
    """
    inputs = Input(input_shape)

    # Contracting path: remember each stage's pre-pooling output for the skips.
    skips = []
    features = inputs
    for width in (16, 32, 64, 128):
        skip, features = encoder_block(features, width)
        skips.append(skip)

    features = conv_block(features, 256)

    # Expanding path: deepest skip first.
    for width, skip in zip((128, 64, 32, 16), reversed(skips)):
        features = decoder_block(features, skip, width)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(features)
    return Model(inputs, outputs, name="U-Net")

In [11]:
def attention_unet_build(input_shape):
    """Assemble the attention U-Net model.

    Four encoder stages (32, 64, 128, 256 filters) feed a 512-filter bridge.
    Each of the four decoder stages (256, 128, 64, 32 filters) first derives
    a gating signal from the current decoder tensor, uses it to attention-gate
    the matching encoder output, and then decodes with the gated skip. A 1x1
    sigmoid convolution with a single filter produces the output.

    Args:
        input_shape: Shape handed to ``Input``.

    Returns:
        A ``Model`` named "AU-Net" mapping the input tensor to the output tensor.
    """
    inputs = Input(input_shape)

    # Encode: remember each stage's pre-pooling output for the skips.
    skips = []
    features = inputs
    for width in (32, 64, 128, 256):
        skip, features = encoder_block(features, width)
        skips.append(skip)

    # Bridge
    features = conv_block(features, 512)

    # Per stage: gating -> attention -> decode, deepest skip first.
    for width, skip in zip((256, 128, 64, 32), reversed(skips)):
        gate = gating_signal(features, width)
        attended_skip = attention_gate(skip, gate, width)
        features = decoder_block(features, attended_skip, width)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(features)
    return Model(inputs, outputs, name="AU-Net")

In [12]:
# def unet_build(input_shape):
    
#     inputs = Input(input_shape)

#     conv1, pool1 = encoder_block(inputs, 32)
#     conv2, pool2 = encoder_block(pool1, 64)
#     conv3, pool3 = encoder_block(pool2, 128) 
#     conv4, pool4 = encoder_block(pool3, 256) 

#     bridge = conv_block(pool4, 512)

#     decoder_1 = decoder_block(bridge, conv4, 256)
#     decoder_2 = decoder_block(decoder_1, conv3, 128)
#     decoder_3 = decoder_block(decoder_2, conv2, 64)
#     decoder_4 = decoder_block(decoder_3, conv1, 32)

#     outputs = Conv2D(1, 1, padding="same", activation="sigmoid") (decoder_4)

#     model = Model(inputs, outputs, name="U-Net")
#     return model