In [22]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
from tqdm import tqdm

In [23]:
# Dummy batch holding a single 572x572, 3-channel image filled with ones;
# it is only used below to smoke-test the network's output shape.
input_image = tf.ones(shape=[1, 572, 572, 3], dtype=tf.float32)
input_image.shape

TensorShape([1, 572, 572, 3])

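# Double-convolution module (`C_Module`): two 3x3 conv → BatchNorm → ReLU units, used in both the contracting and expanding paths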

In [41]:
class C_Module(tf.keras.layers.Layer):
    """Two stacked 3x3 'valid' Conv2D -> BatchNormalization -> ReLU units.

    Each unpadded 3x3 convolution trims one pixel from every border, so the
    block shrinks height and width by 4 in total (e.g. 572 -> 568).

    Args:
        filters: Number of output channels of both convolutions.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can rebuild the layer.
        self.filters = filters
        self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size=3, padding='valid')
        self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size=3, padding='valid')
        # ReLU has no weights, so one instance is safely shared by both units.
        self.act = tf.keras.layers.Activation('relu')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        """Apply conv -> BN -> ReLU twice and return the resulting feature map."""
        x = self.act(self.bn1(self.conv1(inputs)))
        x = self.act(self.bn2(self.conv2(x)))
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({"filters": self.filters})
        return config

# Single convolution module (`conv`): one 3x3 conv → BatchNorm → ReLU unit, used in the bottleneck

In [42]:
class conv(tf.keras.layers.Layer):
    """A single 3x3 'valid' Conv2D -> BatchNormalization -> ReLU unit.

    The unpadded convolution trims one pixel from every border, shrinking
    height and width by 2 (e.g. 32 -> 30).

    Args:
        filters: Number of output channels of the convolution.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can rebuild the layer.
        self.filters = filters
        self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size=3, padding='valid')
        self.act = tf.keras.layers.Activation('relu')
        self.bn1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        """Apply conv -> BN -> ReLU and return the resulting feature map."""
        return self.act(self.bn1(self.conv1(inputs)))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({"filters": self.filters})
        return config

# U-Net

In [96]:
class UNet(tf.keras.layers.Layer):
    """U-Net built from unpadded ('valid') 3x3 convolutions.

    Because every 3x3 convolution is unpadded, the output is smaller than the
    input: a 572x572 input yields a 388x388 map with 2 channels. Encoder
    feature maps are centre-cropped to the decoder's spatial size before each
    skip-connection concatenation.

    Args:
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Contracting path
        self.cont1 = C_Module(64)
        self.cont2 = C_Module(128)
        self.cont3 = C_Module(256)
        self.cont4 = C_Module(512)
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)

        # Bottleneck
        self.bottle1 = conv(1024)
        self.bottle2 = conv(1024)
        self.dropout = tf.keras.layers.Dropout(0.5)

        # Expanding path: each 2x2 stride-2 transposed conv doubles the spatial
        # size and halves the channel count relative to the level below it.
        self.transpose1 = tf.keras.layers.Conv2DTranspose(512, 2, 2, padding='same')
        self.tp_conv1 = C_Module(512)
        self.transpose2 = tf.keras.layers.Conv2DTranspose(256, 2, 2, padding='same')
        self.tp_conv2 = C_Module(256)
        self.transpose3 = tf.keras.layers.Conv2DTranspose(128, 2, 2, padding='same')
        self.tp_conv3 = C_Module(128)
        # 64 filters (128 -> 64), matching the halving pattern above and the
        # 64-channel level (cont1 / tp_conv4) this stage feeds.
        self.transpose4 = tf.keras.layers.Conv2DTranspose(64, 2, 2, padding='same')
        self.tp_conv4 = C_Module(64)

        # Output: 1x1 conv projecting to 2 channels.
        self.out_conv = tf.keras.layers.Conv2D(2, 1, 1, padding='same')
        # NOTE(review): tanh bounds outputs to [-1, 1]; for 2-class segmentation a
        # softmax (or raw logits + a from_logits loss) is more usual — confirm intent.
        self.act = tf.keras.layers.Activation('tanh')

    @staticmethod
    def _center_crop(skip, target):
        """Centre-crop ``skip`` to the height and width of ``target``.

        Unlike a symmetric slice ``skip[:, d:-d, d:-d, :]``, this works when the
        size difference is 0 (where ``-0`` would select nothing) or odd, handles
        height and width independently, and accepts unknown (None) sizes.
        """
        # Prefer static sizes so downstream layers keep known shapes.
        height = target.shape[1] if target.shape[1] is not None else tf.shape(target)[1]
        width = target.shape[2] if target.shape[2] is not None else tf.shape(target)[2]
        return tf.image.resize_with_crop_or_pad(skip, height, width)

    def call(self, inputs):
        """Run the encoder, bottleneck and decoder; return the tanh-activated map."""
        # Contracting path: keep each level's pre-pooling output for the skips.
        skips = []
        x = inputs
        for block in (self.cont1, self.cont2, self.cont3, self.cont4):
            x = block(x)
            skips.append(x)
            x = self.pool(x)

        # Bottleneck (28x28 for a 572x572 input).
        x = self.bottle1(x)
        x = self.bottle2(x)
        x = self.dropout(x)

        # Expanding path. Spatial sizes for a 572x572 input:
        #   up 28 -> 56,   skip 64  cropped to 56,  convs -> 52
        #   up 52 -> 104,  skip 136 cropped to 104, convs -> 100
        #   up 100 -> 200, skip 280 cropped to 200, convs -> 196
        #   up 196 -> 392, skip 568 cropped to 392, convs -> 388
        stages = (
            (self.transpose1, self.tp_conv1),
            (self.transpose2, self.tp_conv2),
            (self.transpose3, self.tp_conv3),
            (self.transpose4, self.tp_conv4),
        )
        for (upsample, block), skip in zip(stages, reversed(skips)):
            x = upsample(x)
            x = tf.concat([x, self._center_crop(skip, x)], axis=-1)
            x = block(x)

        return self.act(self.out_conv(x))

# Model test: output-shape check on the dummy 572x572 input

In [97]:
# Instantiate the network; sub-layer weights are created lazily on the first call below.
layer = UNet()

In [98]:
unet_output = layer(input_image)
# 572x572 in -> 388x388 out: every unpadded 3x3 conv trims the borders.
# Fails loudly on a fresh Restart & Run All if cropping or layer sizes regress.
assert unet_output.shape.as_list() == [1, 388, 388, 2], unet_output.shape
unet_output.shape

TensorShape([1, 388, 388, 2])