In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.layers import *

In [None]:
def UnetModel(input_shape, output_nc, ngf=64,use_bias = False):
  """Build a five-level U-Net for per-pixel classification.

  Encoder: a full-resolution conv_batch_act(ngf) stage followed by four
  downsampling stages, each a 2x2 stride-2 'valid' Conv2D (halving H and W)
  plus conv_batch_act, with ngf*2, ngf*4, ngf*8 and ngf*16 filters.
  Decoder: 2x2 stride-2 Conv2DTranspose upsampling, concatenating the encoder
  feature map of matching resolution along the last (channel) axis.

  Args:
    input_shape: shape of one sample, channels-last (H, W, C). Known H and W
      must be divisible by 16 so each decoder output matches its skip tensor.
    output_nc: number of output channels (classes).
    ngf: filters at the first level; doubled at every downsampling stage.
    use_bias: bias flag for the strided / transposed resampling convolutions.

  Returns:
    A tf.keras Model producing (H, W, output_nc) per-pixel probabilities:
    softmax across channels when output_nc > 1, sigmoid when output_nc == 1
    (a softmax over a single channel is identically 1 and cannot learn).

  Raises:
    ValueError: if a known spatial dimension is not divisible by 16.
  """
  # Four stride-2 stages: otherwise decoder and skip sizes disagree at concat.
  for dim in tuple(input_shape)[:2]:
    if dim is not None and dim % 16 != 0:
      raise ValueError(
          'UnetModel requires H and W divisible by 16, got input_shape=%r'
          % (tuple(input_shape),))

  def down(x, filters):
    # 2x2 stride-2 'valid' conv halves H and W before the conv-BN-ReLU pair.
    x = tf.keras.layers.Conv2D(filters, kernel_size=2, strides=2,
                               padding='valid', use_bias=use_bias)(x)
    return conv_batch_act(filters)(x)

  inputs = tf.keras.layers.Input(shape=input_shape)

  out1 = conv_batch_act(ngf)(inputs)
  out2 = down(out1, ngf * 2)
  out3 = down(out2, ngf * 4)
  out4 = down(out3, ngf * 8)
  out5 = down(out4, ngf * 16)

  # Bottleneck upsampled 2x, then fused with the matching encoder level.
  up5 = tf.keras.layers.Conv2DTranspose(ngf * 8, kernel_size=2, strides=2,
                                        padding='valid', use_bias=use_bias)(out5)
  up5 = tf.keras.layers.Concatenate(axis=-1)([out4, up5])

  up4 = conv_batch_act_up(ngf * 8, ngf * 4, use_bias=use_bias)(up5)
  up4 = tf.keras.layers.Concatenate(axis=-1)([out3, up4])

  up3 = conv_batch_act_up(ngf * 4, ngf * 2, use_bias=use_bias)(up4)
  up3 = tf.keras.layers.Concatenate(axis=-1)([out2, up3])

  up2 = conv_batch_act_up(ngf * 2, ngf, use_bias=use_bias)(up3)
  up2 = tf.keras.layers.Concatenate(axis=-1)([out1, up2])

  out = conv_batch_act(ngf)(up2)
  out = tf.keras.layers.Conv2D(output_nc, kernel_size=1, strides=1,
                               padding='same', use_bias=True)(out)
  # Softmax over one channel is constant 1, so a binary output uses sigmoid.
  final_activation = 'sigmoid' if output_nc == 1 else 'softmax'
  out = tf.keras.layers.Activation(final_activation)(out)

  return tf.keras.models.Model(inputs, out)

                            




In [None]:
def conv_batch_act(inner_nc):
  """Return a builder applying two (3x3 Conv2D -> BatchNorm -> ReLU) stages.

  Args:
    inner_nc: number of filters used by both convolutions.

  Returns:
    A callable mapping a Keras tensor to a tensor with inner_nc channels and
    the same spatial size (stride 1, 'same' padding).
  """
  keras_layers = tf.keras.layers

  def call(tensor):
    features = tensor
    for _ in range(2):
      # Bias omitted because the following BatchNormalization adds its own shift.
      features = keras_layers.Conv2D(inner_nc, kernel_size=3, strides=1,
                                     padding='same', use_bias=False)(features)
      features = keras_layers.BatchNormalization()(features)
      features = keras_layers.ReLU()(features)
    return features

  return call

def conv_batch_act_up(inner_nc,outer_nc,use_bias):
  """Return a decoder builder: two conv-BN-ReLU stages, then 2x upsampling.

  Args:
    inner_nc: filters of the two 3x3 'same' convolutions.
    outer_nc: filters of the final 2x2 stride-2 Conv2DTranspose.
    use_bias: bias flag for the transposed convolution only.

  Returns:
    A callable mapping a Keras tensor to one with outer_nc channels and twice
    the spatial size.
  """
  keras_layers = tf.keras.layers

  def call(tensor):
    features = tensor
    for _ in range(2):
      features = keras_layers.Conv2D(inner_nc, kernel_size=3, strides=1,
                                     padding='same', use_bias=False)(features)
      features = keras_layers.BatchNormalization()(features)
      features = keras_layers.ReLU()(features)

    upsample = keras_layers.Conv2DTranspose(outer_nc, kernel_size=2, strides=2,
                                            padding='valid', use_bias=use_bias)
    return upsample(features)

  return call

In [None]:
# UnetModel(input_shape = (512,512,3), output_nc=3, ngf=32).summary()

In [None]:
import keras.backend as K
class Swish(tf.keras.layers.Layer):
    """Element-wise Swish activation layer, delegating to tf.nn.swish."""

    def __init__(self, name=None, **kwargs):
        super(Swish, self).__init__(name=name, **kwargs)

    def call(self, inputs, **kwargs):
        activated = tf.nn.swish(inputs)
        return activated

    def get_config(self):
        # Serialize the base-layer config, explicitly recording this layer's name.
        layer_config = super(Swish, self).get_config()
        layer_config.update({'name': self.name})
        return layer_config
def squeeze_excite_block(reduce_ratio=0.25,name_block=None):
  """Return a Squeeze-and-Excitation builder that rescales channels.

  The input is averaged over axes 1 and 2 (keepdims), passed through a 1x1
  conv reducing channels to max(1, int(C * reduce_ratio)), Swish, a 1x1 conv
  back to C channels and a sigmoid; the resulting per-channel gates are
  multiplied into the input.

  Args:
    reduce_ratio: fraction of channels kept in the bottleneck convolution.
    name_block: optional name of the final Multiply layer, so the block output
      can be retrieved with Model.get_layer(name_block).

  Returns:
    A callable mapping a Keras tensor to a tensor of the same shape.
  """
  def call(inputs):
    filters = inputs.shape[-1]
    num_reduced_filters = max(1, int(filters * reduce_ratio))
    # Squeeze: per-channel spatial mean. tf.reduce_mean replaces
    # keras.backend.mean from the standalone `keras` package, which mixes two
    # Keras implementations in one tf.keras graph and is absent in Keras 3.
    se = Lambda(lambda a: tf.reduce_mean(a, axis=[1, 2], keepdims=True))(inputs)

    # Excite: a bottleneck MLP expressed as 1x1 convolutions.
    se = Conv2D(num_reduced_filters, kernel_size=[1, 1], strides=[1, 1],
                kernel_initializer='he_normal', padding='same',
                use_bias=True)(se)
    se = Swish()(se)
    se = Conv2D(filters, kernel_size=[1, 1], strides=[1, 1],
                kernel_initializer='he_normal', padding='same',
                use_bias=True)(se)
    se = Activation('sigmoid')(se)

    # name=None lets Keras auto-generate a unique name, same as omitting it.
    return Multiply(name=name_block)([se, inputs])
  return call

def conv_block(filters,block_name=None):
  """Return a builder: two (3x3 conv -> BatchNorm -> Swish) stages, then SE.

  Args:
    filters: number of filters for both 3x3 convolutions.
    block_name: optional name given to the squeeze-excite output layer; the
      encoder relies on these names to expose skip connections.

  Returns:
    A callable mapping a Keras tensor to a tensor with `filters` channels.
  """
  def call(inputs):
    features = inputs
    for _ in range(2):
      features = Conv2D(filters, (3, 3), padding="same", use_bias=False,
                        kernel_initializer='he_normal')(features)
      features = BatchNormalization()(features)
      features = Swish()(features)
    return squeeze_excite_block(name_block=block_name)(features)
  return call

def attention_module(n_filter):
  """Return an additive attention gate that re-weights a skip connection.

  The gating tensor `inputs` is projected by a 1x1 conv + BatchNorm and
  upsampled by a 2x2 stride-2 transposed conv; `skip` is projected by its own
  1x1 conv + BatchNorm. Their sum goes through Swish, a 1x1 conv to a single
  channel, BatchNorm and sigmoid, giving a mask that multiplies `skip`.

  Args:
    n_filter: channels of the intermediate projections.

  Returns:
    call(inputs, skip) returning the gated skip tensor. Presumably `inputs` is
    at half the spatial size of `skip` (true for decoder_block) — TODO confirm
    for other callers.
  """
  def call(inputs,skip):
    gate = Conv2D(n_filter, (1, 1), padding='same',
                  kernel_initializer='he_normal', use_bias=False)(inputs)
    gate = BatchNormalization()(gate)
    gate = Conv2DTranspose(n_filter, (2, 2), strides=(2, 2), padding='same',
                           kernel_initializer='he_normal')(gate)

    skip_proj = Conv2D(n_filter, (1, 1), padding='same',
                       kernel_initializer='he_normal', use_bias=False)(skip)
    skip_proj = BatchNormalization()(skip_proj)

    mask = Swish()(skip_proj + gate)
    mask = Conv2D(1, (1, 1), padding='same',
                  kernel_initializer='he_normal', use_bias=False)(mask)
    mask = BatchNormalization()(mask)
    mask = Activation('sigmoid')(mask)

    # The single-channel mask broadcasts across all channels of `skip`.
    return mask * skip
  return call
def decoder_block(n_filter,skip=None):
  """Return a decoder stage: 2x transposed-conv upsampling, optional gated skip.

  Args:
    n_filter: filters of the upsampling conv and of the trailing conv_block.
    skip: optional encoder tensor; when given, it is passed through
      attention_module and concatenated with the upsampled features.

  Returns:
    A callable mapping the coarser tensor to the decoded, upsampled tensor.
  """
  def call(inputs):
    upsampled = Conv2DTranspose(n_filter, (2, 2), strides=(2, 2), padding='same',
                                kernel_initializer='he_normal')(inputs)
    if skip is None:
      merged = upsampled
    else:
      gated_skip = attention_module(n_filter)(inputs, skip)
      merged = Concatenate()([upsampled, gated_skip])
    return conv_block(n_filter)(merged)
  return call
def dow_block(kernel_size=(2,2),stride=(2,2)):
  """Return a downsampling builder applying MaxPooling2D(kernel_size, stride).

  A new pooling layer is created each time the returned callable is invoked.
  """
  return lambda inputs: MaxPooling2D(kernel_size, strides=stride)(inputs)


In [None]:
def encoderSegnet(input_s=(128,128,1)):
  """Build the encoder: four conv_block + 2x2 max-pool stages, then a bottleneck.

  Stages use 64, 128, 256 and 512 filters; the bottleneck uses 1024. Each
  conv_block output is named 'output_block_<k>' (k = 1..5) so decoders can
  fetch skip connections with Model.get_layer.

  Args:
    input_s: shape of one input sample, (H, W, C).

  Returns:
    A Model from the input to the bottleneck feature map.
  """
  pool = dow_block()
  inp = Input(shape=input_s)
  stage_filters = [64, 128, 256, 512]

  features = inp
  for stage, n_filters in enumerate(stage_filters, start=1):
    features = conv_block(n_filters, block_name='output_block_' + str(stage))(features)
    features = pool(features)

  bottleneck_name = 'output_block_' + str(len(stage_filters) + 1)
  features = conv_block(1024, block_name=bottleneck_name)(features)
  return Model(inp, features)


In [None]:
# Encoder skip layers, deepest first, matching the decoder stage order.
list_skip = [f"output_block_{k}" for k in range(4, 0, -1)]

In [None]:
def seg_net(input_shape= (128,128,1), list_skip = list_skip,out_channels=3):
  """Build the attention-gated encoder/decoder segmentation model.

  The encoder comes from encoderSegnet; four decoder_block stages (512, 256,
  128, 64 filters) each upsample 2x and fuse an attention-gated skip taken
  from the encoder layer named in `list_skip`.

  Args:
    input_shape: shape of one sample, channels-last (H, W, C). Known H and W
      must be divisible by 16 because the encoder max-pools four times.
    list_skip: names of encoder layers used as skips, deepest first, exactly
      one per decoder stage (four).
    out_channels: number of output channels (classes).

  Returns:
    A Model producing per-pixel probabilities: softmax across channels when
    out_channels > 1, sigmoid when out_channels == 1 (a softmax over a single
    channel is identically 1 and cannot learn).

  Raises:
    ValueError: on a spatial size not divisible by 16 or a wrong number of
      skip names.
  """
  num_filters = [512, 256, 128, 64]
  if len(list_skip) != len(num_filters):
    raise ValueError(
        'seg_net expects %d skip layer names, got %d'
        % (len(num_filters), len(list_skip)))
  for dim in tuple(input_shape)[:2]:
    if dim is not None and dim % 16 != 0:
      raise ValueError(
          'seg_net requires H and W divisible by 16, got input_shape=%r'
          % (tuple(input_shape),))

  encoder = encoderSegnet(input_s=input_shape)
  skip_connect = [encoder.get_layer(name).output for name in list_skip]

  o = encoder.output
  for n_filter, skip in zip(num_filters, skip_connect):
    o = decoder_block(n_filter, skip=skip)(o)

  o = Conv2D(out_channels, (1, 1), padding='same', kernel_initializer='he_normal')(o)
  # Softmax over one channel is constant 1, so a binary output uses sigmoid.
  final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
  o = Activation(final_activation)(o)

  return Model(encoder.input, o)

In [None]:

# Build the segmentation model with its default 128x128x1 input / 3 classes
# and print the layer table.
S = seg_net(input_shape=(128, 128, 1), out_channels=3)
S.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 576         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
swish (Swish)                   (None, 128, 128, 64) 0           batch_normalization[0][0]        
_______________________________________________________________________________________