# Import Modules

In [1]:
import tensorflow as tf

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, Activation, MaxPooling2D, concatenate, Dropout, Input
from tensorflow.keras.models import Model

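# Model

A U-Net-style encoder–decoder. A contracting path of convolution + pooling stages keeps each stage's output as a skip connection. An expansive path of transposed-convolution stages then concatenates those skip connections. The network ends in a 1×1 convolution with a sigmoid output.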

In [2]:
class Blocks:
    """Stateful builder for the conv / pool / up-conv stages of a U-Net-style model.

    The builder tracks a level multiplier ``self.n`` that starts at 1.
    ``self.n`` is doubled after every counted ``conv2d_block`` call on the
    contracting path and halved by every ``convTrans_block`` call on the
    expansive path. A stage at level ``k`` uses ``n_filters * k`` filters.

    The outputs of counted conv blocks are stored in ``self.convLayers`` under
    ``'conv_<level>'``. The expansive path concatenates them as skip
    connections.

    Args:
        n_filters: Base number of filters (the filter count at level 1).
    """

    def __init__(self, n_filters):
        self.n_filters = n_filters
        # Level of the next counted (contracting-path) conv block.
        self.n = 1
        # Skip-connection tensors keyed by layer name: 'conv_1', 'conv_2', 'conv_4', ...
        self.convLayers = {}

    def pool_block(self, input_x, pool_size=(2, 2), dropout=0.5):
        """Downsample with max pooling, then apply dropout.

        Args:
            input_x: Input tensor.
            pool_size: Pooling window passed to ``MaxPooling2D``.
            dropout: Dropout rate applied after pooling.

        Returns:
            The pooled tensor after dropout.
        """
        x = MaxPooling2D(pool_size=pool_size)(input_x)
        x = Dropout(dropout)(x)

        return x

    def conv2d_block(self, input_x, kernel_size=(3,3), pad='same', count=True):
        """Apply two (Conv2D -> ReLU -> BatchNormalization) stacks.

        With ``count=True`` (contracting path), the block works at level
        ``self.n``. The second Conv2D is named ``'conv_<level>'``. The block's
        final output is stored in ``self.convLayers`` as a skip connection,
        and ``self.n`` is doubled afterwards.

        With ``count=False`` (expansive path, called after ``convTrans_block``),
        the block works at level ``self.n // 2``. That is the level of the skip
        connection that ``convTrans_block`` just concatenated, so the decoder
        stage has the same filter count as the matching encoder stage. The
        second Conv2D is named ``'conv_ePath_<level>'`` and ``self.n`` is left
        unchanged.

        Args:
            input_x: Input tensor.
            kernel_size: Kernel size for both Conv2D layers.
            pad: Padding mode for both Conv2D layers.
            count: True on the contracting path, False on the expansive path.

        Returns:
            The output of the second BatchNormalization layer.

        Raises:
            ValueError: If the computed level is < 1, which would give a
                Conv2D with no filters.
        """
        if count:
            level = self.n
            name = f'conv_{level}'
        else:
            level = self.n // 2
            name = f'conv_ePath_{level}'

        if level < 1:
            raise ValueError(
                f'invalid level {level} for n = {self.n} (count={count}); '
                'count=False blocks must follow convTrans_block'
            )

        filters = self.n_filters * level

        x = Conv2D(filters=filters, kernel_size=kernel_size, padding=pad)(input_x)
        x = Activation('relu')(x)
        x = BatchNormalization()(x)

        # Only the second Conv2D carries the explicit, level-based name.
        x = Conv2D(filters=filters, kernel_size=kernel_size, padding=pad, name=name)(x)
        x = Activation('relu')(x)
        x = BatchNormalization()(x)

        if count:
            # Store the block output as this level's skip connection,
            # then move to the next (doubled) level.
            self.convLayers[name] = x
            self.n *= 2

        return x

    def convTrans_block(self, input_x, kernel_size=(3,3), strides=(2, 2), pad='same', dropout=0.5):
        """Upsample, concatenate the matching skip connection, then apply dropout.

        Halves ``self.n``. The skip connection used is the contracting-path
        block stored as ``'conv_<level>'``, where ``level`` is the halved
        ``self.n // 2``. The transposed convolution uses
        ``n_filters * level`` filters so that it matches that skip
        connection's filter count.

        Args:
            input_x: Input tensor from the previous (lower-resolution) stage.
            kernel_size: Kernel size of the transposed convolution.
            strides: Upsampling strides of the transposed convolution.
            pad: Padding mode of the transposed convolution.
            dropout: Dropout rate applied after concatenation.

        Returns:
            The concatenated tensor after dropout.

        Raises:
            ValueError: If no skip connection is stored for the target level.
                ``self.n`` is left unchanged in that case.
        """
        new_n = self.n // 2
        level = new_n // 2
        conv_name = f'conv_{level}'

        # Validate before mutating state so a failed call does not corrupt self.n.
        if conv_name not in self.convLayers:
            raise ValueError(
                f'no skip connection {conv_name!r} for n = {self.n}; '
                f'available: {list(self.convLayers)}'
            )

        self.n = new_n

        x = Conv2DTranspose(filters=self.n_filters * level, kernel_size=kernel_size, strides=strides, padding=pad)(input_x)
        x = concatenate([x, self.convLayers[conv_name]])
        x = Dropout(dropout)(x)

        return x

In [3]:
# Input image shape in channels-last order: (height, width, channels).
SHAPE = (256, 256, 3)

In [4]:
block = Blocks(n_filters=16)

inputs = Input(shape=SHAPE)

# Contracting path: four (conv block -> pool) stages. Each counted conv block
# stores its output on `block` as a skip connection. A fifth conv block then
# runs at the lowest resolution.
x = inputs
for stage in range(4):
    x = block.conv2d_block(x)
    x = block.pool_block(x)
x = block.conv2d_block(x)

# Expansive path: four (transposed conv + skip concat -> conv block) stages.
# The conv blocks use count=False so they are not stored as skip connections.
for stage in range(4):
    x = block.convTrans_block(x)
    x = block.conv2d_block(x, count=False)

# 1x1 convolution to a single sigmoid channel: one value in (0, 1) per location.
outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

In [5]:
# Wire the functional graph from `inputs` to `outputs`. Both are list-wrapped,
# as in the original, so the model's input/output structure stays a list.
model = Model(inputs=[inputs], outputs=[outputs])

# Print the layer-by-layer table (shapes, parameter counts, connections).
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 16) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 16) 64          activation[0][0]                 
_______________________________________________________________________________________