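<h1>U-Net</h1>

A subclassed Keras implementation of U-Net has four parts:

- four convolutional encoder stages with max-pooling;
- a bottleneck;
- four decoder stages that upsample with transposed convolutions and concatenate the matching encoder features (skip connections);
- a final 1×1 output convolution.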

In [2]:
import tensorflow as tf

In [4]:
class Conv_Block_Encoder(tf.keras.layers.Layer):
    """Two identically configured Conv2D layers applied back to back.

    Used as a U-Net encoder / bottleneck stage (pooling is done by the caller).

    Args:
        filters: number of output channels of each Conv2D.
        kernel_size: kernel size of each Conv2D.
        activation: activation applied by each Conv2D.
        padding: padding mode of each Conv2D (e.g. 'same').
    """

    def __init__(self, filters, kernel_size, activation, padding):
        super().__init__()

        # Both convolutions of the stage share exactly the same settings.
        conv_config = dict(
            filters=filters,
            kernel_size=kernel_size,
            activation=activation,
            padding=padding,
        )
        self.conv1 = tf.keras.layers.Conv2D(**conv_config)
        self.conv2 = tf.keras.layers.Conv2D(**conv_config)

    def call(self, X):
        """Return ``conv2(conv1(X))``."""
        return self.conv2(self.conv1(X))
    


class Conv_Block_Decoder(tf.keras.layers.Layer):
    """One U-Net decoder stage: upsample, concatenate the skip tensor, then two Conv2D layers.

    Args:
        filters: output channels of the two Conv2D layers.
        kernel_size: kernel size of the two Conv2D layers.
        activation: activation of the two Conv2D layers.
        padding: padding mode of the Conv2D layers and of the Conv2DTranspose layer.
        pool_size: accepted for signature compatibility; not used by this block.
        up_filters: output channels of the Conv2DTranspose upsampling layer.
        up_kernel: kernel size of the Conv2DTranspose layer.
        up_stride: strides of the Conv2DTranspose layer (the upsampling factor).
    """

    def __init__(self, filters, kernel_size, activation, padding, pool_size, up_filters, up_kernel, up_stride):
        super().__init__()

        # Channel counts may arrive as floats (e.g. computed with `/`); Keras layers expect ints.
        filters = int(filters)
        up_filters = int(up_filters)

        self.conv1 = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)
        self.conv2 = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)
        self.trans = tf.keras.layers.Conv2DTranspose(filters=up_filters, kernel_size=up_kernel, strides=up_stride, padding=padding)
        # Created once here instead of re-instantiating a Concatenate layer on every forward pass.
        # axis=-1 is the same as axis=3 for 4-D (batch, H, W, C) tensors.
        self.concat = tf.keras.layers.Concatenate(axis=-1)

    def call(self, X1, X2):
        """Upsample ``X1``, join it with ``X2`` on the last axis, then apply conv1 and conv2.

        Args:
            X1: lower-resolution feature map from the previous decoder / bottleneck stage.
            X2: skip-connection feature map; presumably already at the upsampled
                resolution of ``X1`` (otherwise concatenation fails) — caller's responsibility.
        """
        up = self.trans(X1)
        out = self.concat([up, X2])
        out = self.conv1(out)
        return self.conv2(out)

In [5]:
class UNet(tf.keras.Model):
    """U-Net encoder/decoder network as a subclassed Keras model.

    Architecture:

    - Four encoder stages (``start_filter`` x 1, 2, 4, 8 filters), each followed by max-pooling.
    - A bottleneck with ``16 * start_filter`` filters.
    - Four decoder stages. Each upsamples with ``Conv2DTranspose`` and concatenates the
      matching encoder output (skip connection).
    - A final 1x1 convolution producing ``out_channels`` channels.

    Note: with the default ``pool_size == up_stride == (2, 2)``, the input height and width
    should be divisible by 16. Otherwise the upsampled and skip tensors will not line up.

    Args:
        start_filter: filters of the first encoder stage; doubled at each level.
        kernel_size: kernel size of the encoder/decoder convolutions.
        pool_size: max-pooling window between encoder stages.
        activation: activation of the encoder/decoder convolutions.
        padding: padding mode of the convolutions and transposed convolutions.
        up_kernel: kernel size of the decoder's transposed convolutions.
        up_stride: stride of the transposed convolutions; should undo ``pool_size``.
        out_channels: channels of the final 1x1 convolution (default 3, as before).
        out_activation: activation of the final 1x1 convolution (default 'sigmoid', as before).
    """

    def __init__(self, start_filter=64, kernel_size=(3,3), pool_size=(2,2), activation='relu', padding='same', up_kernel=(2,2), up_stride=(2,2), out_channels=3, out_activation='sigmoid'):
        super().__init__()

        self.ground_size_filter = 16 * start_filter

        self.pool = tf.keras.layers.MaxPooling2D(pool_size)

        encoder_kwargs = dict(kernel_size=kernel_size, activation=activation, padding=padding)
        self.c1 = Conv_Block_Encoder(filters=start_filter, **encoder_kwargs)
        self.c2 = Conv_Block_Encoder(filters=2 * start_filter, **encoder_kwargs)
        self.c3 = Conv_Block_Encoder(filters=4 * start_filter, **encoder_kwargs)
        self.c4 = Conv_Block_Encoder(filters=8 * start_filter, **encoder_kwargs)

        # Bottleneck.
        self.cg5 = Conv_Block_Encoder(filters=self.ground_size_filter, **encoder_kwargs)

        decoder_kwargs = dict(kernel_size=kernel_size, activation=activation, padding=padding,
                              pool_size=pool_size, up_kernel=up_kernel, up_stride=up_stride)
        g = self.ground_size_filter
        # Integer division keeps filter counts as ints (`/` produced floats such as 512.0).
        self.d4 = Conv_Block_Decoder(filters=g, up_filters=g // 2, **decoder_kwargs)
        self.d3 = Conv_Block_Decoder(filters=g // 2, up_filters=g // 4, **decoder_kwargs)
        self.d2 = Conv_Block_Decoder(filters=g // 4, up_filters=g // 8, **decoder_kwargs)
        self.d1 = Conv_Block_Decoder(filters=g // 8, up_filters=g // 16, **decoder_kwargs)

        # The output head is created once, here, so its weights are tracked, saved and trained.
        # Previously a fresh, randomly initialised Conv2D was built on every call().
        self.out_conv = tf.keras.layers.Conv2D(out_channels, (1, 1), activation=out_activation, padding='same')

    def _summary(self):
        """Print a placeholder summary banner; returns None."""
        print('Custom Summary')

    def call(self, X):
        """Forward pass: encoder -> bottleneck -> decoder with skips -> 1x1 output conv."""
        # Encoder: keep each stage's pre-pool output for the skip connections.
        e1 = self.c1(X)
        e2 = self.c2(self.pool(e1))
        e3 = self.c3(self.pool(e2))
        e4 = self.c4(self.pool(e3))

        # Bottleneck.
        g5 = self.cg5(self.pool(e4))

        # Decoder: upsample and fuse with the matching encoder stage.
        d4 = self.d4(g5, e4)
        d3 = self.d3(d4, e3)
        d2 = self.d2(d3, e2)
        d1 = self.d1(d2, e1)

        return self.out_conv(d1)
    
    
    
    
        
        
    

In [298]:
# Smoke test on a fresh Keras session: push one random 800x800, 3-channel input through the net.
tf.keras.backend.clear_session()
model = UNet()
X = tf.random.uniform((1, 800, 800, 3))
sample_out = model(X)
# With 'same' padding and a side divisible by 16, U-Net should preserve batch and spatial size.
assert sample_out.shape[:3] == X.shape[:3], sample_out.shape
sample_out.shape

TensorShape([1, 800, 800, 3])

In [279]:
# Explicitly build with a (1, 800, 1024, 3) input shape — presumably (batch, H, W, C), channels-last.
# NOTE(review): this cell's execution count (279) is lower than the call cell's (298); the model
# is already built by `model(X)` above. Verify the cell order under Restart & Run All.
model.build(input_shape=(1, 800, 1024, 3))

In [304]:
# Print the layer table at 100 characters per line.
model.summary(line_length=100)

Model: "u_net"
____________________________________________________________________________________________________
Layer (type)                                 Output Shape                            Param #        
conv__block__encoder (Conv_Block_Encoder)    multiple                                38720          
____________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)                 multiple                                0              
____________________________________________________________________________________________________
conv__block__encoder_1 (Conv_Block_Encoder)  multiple                                221440         
____________________________________________________________________________________________________
conv__block__encoder_2 (Conv_Block_Encoder)  multiple                                885248         
____________________________________________________________________________