In [4]:
from tensorflow.keras import layers
from tensorflow.keras import Model

In [3]:
def encoding_base_block(input_tensor, filters, kernal_size, activation='relu', batch_normalization=True):
    """Apply two (Conv2D -> optional BatchNormalization -> Activation) stages.

    This is the repeated convolutional block used by the U-Net encoder.

    Args:
        input_tensor: Keras tensor to transform (presumably 4-D image
            features -- TODO confirm against callers).
        filters: Number of output channels for both Conv2D layers.
        kernal_size: Kernel size for both Conv2D layers (int or tuple). The
            parameter keeps its original spelling so existing callers keep
            working; it is forwarded to Keras as ``kernel_size``.
        activation: Activation name passed to ``layers.Activation``.
        batch_normalization: If True, insert a BatchNormalization layer
            between each convolution and its activation.

    Returns:
        The output tensor of the second activation. Because ``padding='same'``
        is used with Conv2D's default stride of 1, height/width are preserved.
    """
    x = input_tensor
    # Two identical conv stages; the loop replaces the copy-pasted original.
    for _ in range(2):
        x = layers.Conv2D(
            filters=filters,
            kernel_size=kernal_size,          # Keras spelling is "kernel_"
            padding='same',
            kernel_initializer='he_normal',   # was the invalid kwarg "kernal_initializer"
        )(x)
        if batch_normalization:
            x = layers.BatchNormalization()(x)
        x = layers.Activation(activation)(x)
    return x


    
    


In [6]:
def generate_unet(image_tensor, n_filter=16, drop_out=0.1, batch_normalization=True, kernel_size=(2, 2)):
    """Build a Keras Model for the U-Net contracting (encoder) path.

    NOTE(review): only the encoder is implemented so far. The model output is
    still the original draft's placeholder ``pool1`` (the first pooled +
    dropped-out stage). The deeper stages and the bottleneck are built on the
    graph but are not reachable from the returned model's output.
    TODO: add the expanding (decoder) path with skip connections and a final
    prediction layer, then use that as ``output``.

    Args:
        image_tensor: Keras input tensor. ``Model`` requires this to be a
            symbolic input (presumably from ``layers.Input`` -- TODO confirm
            at the call site).
        n_filter: Base filter count; encoder stage ``i`` (0-based) uses
            ``n_filter * 2**i`` filters and the bottleneck uses ``n_filter * 16``.
        drop_out: Dropout rate applied after every max-pooling layer.
        batch_normalization: Forwarded to every ``encoding_base_block`` call.
        kernel_size: Convolution kernel size for every block. It defaults to
            the previously hard-coded ``(2, 2)``.

    Returns:
        A ``tensorflow.keras.Model`` mapping ``image_tensor`` to ``output``.
    """
    x = image_tensor

    # Encoding/reduction section: four conv blocks. Each is followed by 2x2
    # max-pooling and dropout, and the filter count doubles per stage.
    encoder_convs = []   # pre-pooling features, kept for future skip connections
    encoder_pools = []   # post-pooling/dropout tensors (pool1..pool4)
    for stage in range(4):
        conv = encoding_base_block(
            input_tensor=x,
            filters=n_filter * 2 ** stage,
            kernal_size=kernel_size,
            batch_normalization=batch_normalization,
        )
        x = layers.MaxPool2D((2, 2))(conv)
        x = layers.Dropout(drop_out)(x)
        encoder_convs.append(conv)
        encoder_pools.append(x)

    # Bottleneck block at the lowest resolution. It is not yet connected to
    # the model output (see TODO in the docstring).
    bottleneck = encoding_base_block(
        input_tensor=x,
        filters=n_filter * 16,
        kernal_size=kernel_size,
        batch_normalization=batch_normalization,
    )

    # Placeholder output carried over from the original draft (== pool1).
    output = encoder_pools[0]

    model = Model(inputs=[image_tensor], outputs=[output])
    return model