In [10]:
import tensorflow as tf
import keras
from keras import layers
import matplotlib.pyplot as plt


In [11]:
class DoubleConvLayer(layers.Layer):
    """Two stacked 3x3 convolutions with ReLU, ``same`` padding and He-normal init.

    This is the convolution unit used by every encoder stage, the bottleneck
    and every decoder stage of the U-Net in this notebook.

    Args:
        n_filters: Number of output channels of both convolutions.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, n_filters, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can re-create the layer when the model is saved/loaded.
        self.n_filters = n_filters
        self.conv1 = layers.Conv2D(n_filters, 3, padding="same", activation="relu", kernel_initializer="he_normal")
        self.conv2 = layers.Conv2D(n_filters, 3, padding="same", activation="relu", kernel_initializer="he_normal")

    def call(self, inputs):
        """Apply ``conv1`` then ``conv2``; ``same`` padding keeps the spatial size."""
        out = self.conv1(inputs)
        out = self.conv2(out)
        return out

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this layer."""
        config = super().get_config()
        config.update({"n_filters": self.n_filters})
        return config

In [12]:
class DownsampleLayer(layers.Layer):
    """One U-Net encoder stage: double convolution, 2x2 max-pool, then dropout.

    Args:
        n_filters: Channels produced by the double convolution.
        dropout_rate: Dropout rate applied to the pooled map (default 0.3,
            the value that was previously hard-coded).
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    ``call`` returns ``(large, small)``:
        large: the full-resolution convolution output (used as the skip tensor).
        small: the pooled, dropped-out map that feeds the next stage.
    """

    def __init__(self, n_filters, dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can re-create the layer when the model is saved/loaded.
        self.n_filters = n_filters
        self.dropout_rate = dropout_rate
        self.conv = DoubleConvLayer(n_filters)
        self.pool = layers.MaxPool2D(2)
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        large = self.conv(inputs)
        small = self.pool(large)
        # Forward the training flag explicitly so Dropout only drops units in training mode.
        small = self.dropout(small, training=training)
        return large, small

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this layer."""
        config = super().get_config()
        config.update({"n_filters": self.n_filters, "dropout_rate": self.dropout_rate})
        return config

In [61]:
class UpsampleLayer(layers.Layer):
    """One U-Net decoder stage: upsample, concatenate the skip tensor, dropout, double conv.

    Args:
        n_filters: Channels of the transposed convolution and the double convolution.
        dropout_rate: Dropout rate applied after concatenation (default 0.3,
            the value that was previously hard-coded).
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    ``call(inputs, conv_features)`` upsamples ``inputs`` with a stride-2
    transposed convolution, then concatenates ``conv_features`` along the last
    axis. The two tensors must therefore agree in every other dimension.
    """

    def __init__(self, n_filters, dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can re-create the layer when the model is saved/loaded.
        self.n_filters = n_filters
        self.dropout_rate = dropout_rate
        self.conv_trans = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")
        # Build the (weightless) Concatenate layer once. The functional helper
        # layers.concatenate() would create a new layer on every forward pass.
        self.concat = layers.Concatenate()
        self.dropout = layers.Dropout(dropout_rate)
        self.conv = DoubleConvLayer(n_filters)

    def call(self, inputs, conv_features, training=None):
        out = self.conv_trans(inputs)
        out = self.concat([out, conv_features])
        # Forward the training flag explicitly so Dropout only drops units in training mode.
        out = self.dropout(out, training=training)
        out = self.conv(out)
        return out

    def get_config(self):
        """Return the constructor arguments so Keras can serialize this layer."""
        config = super().get_config()
        config.update({"n_filters": self.n_filters, "dropout_rate": self.dropout_rate})
        return config

In [62]:
class UnetModel(keras.Model):
    """Subclassed U-Net for per-pixel classification.

    Structure:
        - Four encoder stages (64/128/256/512 filters).
        - A 1024-filter double-conv bottleneck.
        - Four decoder stages that consume the encoder skip tensors in reverse order.
        - A 1x1 softmax convolution head.

    Args:
        name: Model name (default ``'unet'``).
        num_classes: Output channels of the softmax head (default 11, the value
            that was previously hard-coded).
        image_shape: ``(height, width, channels)`` of the inputs used to build the
            weights eagerly, so ``summary()`` works before the first call
            (default ``(128, 128, 3)``). Height and width should be divisible
            by 16. The four 2x max-pools and four stride-2 transposed convs
            must produce shapes that match at every skip concatenation.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, name='unet', num_classes=11, image_shape=(128, 128, 3), **kwargs):
        super().__init__(name=name, **kwargs)

        self.down1 = DownsampleLayer(64)
        self.down2 = DownsampleLayer(128)
        self.down3 = DownsampleLayer(256)
        self.down4 = DownsampleLayer(512)

        self.middle_conv = DoubleConvLayer(1024)

        self.up1 = UpsampleLayer(512)
        self.up2 = UpsampleLayer(256)
        self.up3 = UpsampleLayer(128)
        self.up4 = UpsampleLayer(64)

        self.last_conv = layers.Conv2D(num_classes, 1, padding="same", activation="softmax")

        # Create the weights now; the leading None is the (unknown) batch dimension.
        self.build(input_shape=(None,) + tuple(image_shape))

    def call(self, inputs):
        """Run encoder -> bottleneck -> decoder and return per-pixel softmax scores."""
        f1, p1 = self.down1(inputs)
        f2, p2 = self.down2(p1)
        f3, p3 = self.down3(p2)
        f4, p4 = self.down4(p3)

        bottleneck = self.middle_conv(p4)

        # Each decoder stage is paired with the encoder output of matching resolution.
        u1 = self.up1(bottleneck, f4)
        u2 = self.up2(u1, f3)
        u3 = self.up3(u2, f2)
        u4 = self.up4(u3, f1)

        outputs = self.last_conv(u4)
        return outputs

In [63]:
# Instantiate the subclassed U-Net. Its constructor calls build() for
# 128x128x3 inputs, so summary() below can report parameter counts.
model = UnetModel()

In [64]:
# Print the layer/parameter table. Subclassed models list output shapes as "multiple".
model.summary()

Model: "unet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 downsample_layer_72 (Downsa  multiple                 38720     
 mpleLayer)                                                      
                                                                 
 downsample_layer_73 (Downsa  multiple                 221440    
 mpleLayer)                                                      
                                                                 
 downsample_layer_74 (Downsa  multiple                 885248    
 mpleLayer)                                                      
                                                                 
 downsample_layer_75 (Downsa  multiple                 3539968   
 mpleLayer)                                                      
                                                                 
 double_conv_layer_166 (Doub  multiple                 1415782

: 

In [2]:
def double_conv_block(x, n_filters):
    """Functional-API double convolution: two 3x3 ``same``-padded ReLU convs.

    Args:
        x: Input tensor.
        n_filters: Output channels of both convolutions.

    Returns:
        The output tensor. ``same`` padding keeps the spatial size.
    """
    x = layers.Conv2D(n_filters, 3, padding="same", activation="relu", kernel_initializer="he_normal")(x)
    x = layers.Conv2D(n_filters, 3, padding="same", activation="relu", kernel_initializer="he_normal")(x)
    return x


def downsample_block(x, n_filters, dropout_rate=0.3):
    """Functional-API encoder stage.

    Args:
        x: Input tensor.
        n_filters: Channels of the double convolution.
        dropout_rate: Dropout rate applied after pooling (default 0.3).

    Returns:
        ``(f, p)``: ``f`` is the full-resolution skip tensor; ``p`` is the
        2x max-pooled, dropped-out tensor for the next stage.
    """
    f = double_conv_block(x, n_filters)
    p = layers.MaxPool2D(2)(f)
    p = layers.Dropout(dropout_rate)(p)
    return f, p


def upsample_block(x, conv_features, n_filters, dropout_rate=0.3):
    """Functional-API decoder stage.

    Upsamples ``x`` with a stride-2 transposed conv, concatenates the skip
    tensor ``conv_features`` on the last axis, applies dropout, then a
    double convolution.
    """
    x = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    x = layers.concatenate([x, conv_features])
    x = layers.Dropout(dropout_rate)(x)
    x = double_conv_block(x, n_filters)
    return x


def build_unet_model(num_classes=11, input_shape=(128, 128, 3)):
    """Build the functional-API U-Net (same topology as ``UnetModel``).

    Args:
        num_classes: Output channels of the final 1x1 softmax conv (default 11).
        input_shape: ``(height, width, channels)`` of the input images
            (default ``(128, 128, 3)``). Height and width should be divisible
            by 16 so the skip concatenations line up after four 2x pools.

    Returns:
        A ``keras.Model`` named ``"U-Net"``.
    """
    inputs = layers.Input(shape=input_shape)

    # downsampling
    f1, p1 = downsample_block(inputs, 64)
    f2, p2 = downsample_block(p1, 128)
    f3, p3 = downsample_block(p2, 256)
    f4, p4 = downsample_block(p3, 512)

    # 5 - bottleneck
    bottleneck = double_conv_block(p4, 1024)

    # upsampling: each stage consumes the encoder skip of matching resolution
    u6 = upsample_block(bottleneck, f4, 512)
    u7 = upsample_block(u6, f3, 256)
    u8 = upsample_block(u7, f2, 128)
    u9 = upsample_block(u8, f1, 64)

    # outputs
    outputs = layers.Conv2D(num_classes, 1, padding="same", activation="softmax")(u9)
    # keras.Model (not tf.keras.Model) so the model class comes from the same
    # Keras package as the `layers` used above.
    unet_model = keras.Model(inputs, outputs, name="U-Net")
    return unet_model