In [None]:
# Mount Google Drive at /content/drive (Colab only); a later cell changes into a folder under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Encoder

In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import models,layers
from tensorflow.keras.utils import get_file
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import re
import sys
from collections import namedtuple
import tensorflow.keras.backend as K
import math

%cd /content/drive/MyDrive/MRI_ACDCA/
from keras_vision_transformer import swin_layers
from keras_vision_transformer import transformer_layers
from keras_vision_transformer import utils

/content/drive/.shortcut-targets-by-id/1kdKWMtQq063wklZtXKlzGd-nUeOnEiDw/MRI_ACDCA


### Bottleneck

In [None]:
class bottle_neck(tf.keras.layers.Layer):
  """Bottleneck stage: ASPP -> strided 3x3 conv -> two residual blocks -> LeakyReLU -> Dropout.

  `ASPP` and `residual_block` in this notebook are functions that take a tensor
  and return a tensor; they are not layer classes and cannot be instantiated in
  `__init__`. The whole stage is therefore assembled exactly once, in `build`,
  as an inner functional `Model`, so that every weight it creates is tracked by
  this layer.

  Args:
    filter: number of channels used by ASPP, the strided convolution and both
      residual blocks.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filter=1024, **kwargs):
    # The base initialiser must run before any attribute is set, otherwise
    # Keras cannot track sub-layers or weights.
    super().__init__(**kwargs)
    self.filter = filter
    self.block = None  # inner functional model, created in build()

  def build(self, input_shape):
    feature_shape = tf.TensorShape(input_shape).as_list()[1:]
    # ASPP uses the input's static height and width as its pooling window,
    # so both must be known when the stage is built.
    if len(feature_shape) != 3 or feature_shape[0] is None or feature_shape[1] is None:
      raise ValueError(
          'bottle_neck needs a 4D input with static height and width, got shape {}'.format(input_shape))

    inputs = Input(shape=feature_shape)
    x = ASPP(inputs, self.filter)
    x = Conv2D(filters=self.filter, kernel_size=(3,3), strides=(2,2), padding='same')(x)
    x = residual_block(x, num_filters=self.filter)
    x = residual_block(x, num_filters=self.filter)
    x = LeakyReLU(alpha=0.1)(x)
    x = Dropout(0.5)(x)
    self.block = Model(inputs, x)
    super().build(input_shape)

  def call(self, x, training=None):
    # `training` is forwarded so BatchNormalization and Dropout inside the
    # inner model switch between training and inference behaviour.
    return self.block(x, training=training)

  def get_config(self):
    config = super().get_config()
    config['filter'] = self.filter
    return config

## Preparation

In [None]:
import keras.backend as K
class Swish(tf.keras.layers.Layer):
    """Layer wrapper that applies `tf.nn.swish` to its input."""

    def __init__(self, name=None, **kwargs):
        super(Swish, self).__init__(name=name, **kwargs)

    def call(self, inputs, **kwargs):
        # Extra keyword arguments are accepted but not used.
        activated = tf.nn.swish(inputs)
        return activated

    def get_config(self):
        # Start from the base-layer config and (re)write the layer name into it.
        base_config = super(Swish, self).get_config()
        base_config.update(name=self.name)
        return base_config
def squeeze_excite_block(reduce_ratio=0.25,name_block=None):
  """Return a function that applies a squeeze-and-excitation gate to a tensor.

  The returned function averages its input over axes 1 and 2, passes the result
  through two 1x1 convolutions (Swish after the first, sigmoid after the
  second) and multiplies the input by the resulting per-channel weights.

  Args:
    reduce_ratio: fraction of the input channels kept by the first 1x1
      convolution (at least one channel is always kept).
    name_block: optional name given to the final Multiply layer.
  """
  def call(inputs):
    channels = inputs.shape[-1]
    squeezed_channels = max(1, int(channels * reduce_ratio))
    # Both 1x1 convolutions share every setting except the filter count.
    pointwise = dict(kernel_size=[1, 1], strides=[1, 1], kernel_initializer='he_normal',
                     padding='same', use_bias=True)

    gate = Lambda(lambda a: K.mean(a, axis=[1,2], keepdims=True))(inputs)
    gate = Conv2D(squeezed_channels, **pointwise)(gate)
    gate = Swish()(gate)
    gate = Conv2D(channels, **pointwise)(gate)
    gate = Activation('sigmoid')(gate)

    scale = Multiply() if name_block is None else Multiply(name=name_block)
    return scale([gate, inputs])
  return call

def conv_block(inputs, filters,kernel_size = (3,3), dilation = 1,block_name=None):
    """Two Conv2D -> BatchNormalization -> Swish units followed by a squeeze-excite gate.

    Args:
        inputs: input tensor.
        filters: number of filters of both convolutions.
        kernel_size: kernel size of both convolutions.
        dilation: dilation rate of both convolutions.
        block_name: forwarded to `squeeze_excite_block` as `name_block`.
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(filters, kernel_size, padding="same", dilation_rate=dilation,
                          use_bias=False, kernel_initializer='he_normal')(features)
        features = BatchNormalization()(features)
        features = Swish()(features)

    gate = squeeze_excite_block(name_block=block_name)
    return gate(features)

In [None]:
def swin_transformer_stack(X, stack_num, embed_dim, num_patch, num_heads, window_size, num_mlp, shift_window=True, name=''):
    '''
    Stacked Swin Transformers that share the same token size.

    Alternated Window-MSA and Swin-MSA will be configured if `shift_window=True`, Window-MSA only otherwise.
    *Dropout is turned off.

    Arguments:
        X -- input tensor; it is passed through the blocks one after another
        stack_num {int} -- number of SwinTransformerBlock layers to chain
        embed_dim, num_patch, num_heads, window_size, num_mlp -- forwarded unchanged to every block
        shift_window {bool} -- if True, odd-numbered blocks use a shift of `window_size // 2`
        name {str} -- prefix of the `name` argument given to each block; block `i` receives
                      `'<name><i>'`. An empty prefix falls back to `'name'`, which reproduces
                      the identifiers `'name0'`, `'name1'`, ... used before the prefix was honoured.

    Returns:
        output tensor of the last block
    '''
    # Turn-off dropouts
    mlp_drop_rate = 0 # Dropout after each MLP layer
    attn_drop_rate = 0 # Dropout after Swin-Attention
    proj_drop_rate = 0 # Dropout at the end of each Swin-Attention block, i.e., after linear projections
    drop_path_rate = 0 # Drop-path within skip-connections

    qkv_bias = True # Convert embedded patches to query, key, and values with a learnable additive value
    qk_scale = None # None: Re-scale query based on embed dimensions per attention head # Float for user specified scaling factor

    shift_size = window_size // 2 if shift_window else 0

    # Previously `name` was ignored and every stack passed the same literals
    # 'name0', 'name1', ... to its blocks; use the caller's prefix so the
    # identifiers differ from one stack to the next.
    prefix = name if name else 'name'

    for i in range(stack_num):
        # Even-indexed blocks are un-shifted; odd-indexed blocks use the shift.
        shift_size_temp = 0 if i % 2 == 0 else shift_size

        X = swin_layers.SwinTransformerBlock(dim=embed_dim, num_patch=num_patch, num_heads=num_heads,
                                 window_size=window_size, shift_size=shift_size_temp, num_mlp=num_mlp, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 mlp_drop=mlp_drop_rate, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, drop_path_prob=drop_path_rate,
                                 name='{}{}'.format(prefix, i))(X)
    return X

In [None]:
def mvn(tensor):
    '''Mean-variance normalisation over axes 1 and 2 (per channel for channels-last 4D input — TODO confirm layout with callers).'''
    eps = 1e-6
    spatial_axes = (1, 2)
    centre = K.mean(tensor, axis=spatial_axes, keepdims=True)
    spread = K.std(tensor, axis=spatial_axes, keepdims=True)
    # eps keeps the division finite where the standard deviation is zero.
    return (tensor - centre) / (spread + eps)

In [None]:
def output_block(inputs):
    """Single-channel 1x1 convolution followed by a sigmoid activation."""
    logits = Conv2D(1, (1, 1), padding="same")(inputs)
    return Activation('sigmoid')(logits)

def output_block1(inputs):
    """Three-channel 1x1 convolution followed by a softmax activation."""
    logits = Conv2D(3, (1, 1), padding="same")(inputs)
    return Activation('softmax')(logits)

In [None]:
def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='swish', name=None):
    '''
    2D convolution (no bias) followed by batch normalization and an optional activation

    Arguments:
        x {keras layer} -- input layer
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters

    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(1, 1)})
        activation {str} -- activation function; pass None to skip the activation (default: {'swish'})
        name {str} -- name of the activation layer; unused when activation is None (default: {None})

    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    # Normalisation runs over the last axis of a 4D tensor, without a learnable scale.
    x = BatchNormalization(axis=3, scale=False)(x)

    if activation is None:
        return x

    return Activation(activation, name=name)(x)

def ResPath(filters, length, inp):
    '''
    ResPath: a chain of residual units, each adding a 1x1 shortcut to a 3x3 convolution

    Arguments:
        filters {int} -- number of filters of every convolution in the path
        length {int} -- length of ResPath; at least one unit is always built
        inp {keras layer} -- input layer

    Returns:
        [keras layer] -- [output layer]
    '''

    out = inp
    # The first unit is unconditional, hence max(1, length) units overall.
    for _ in range(max(1, length)):
        shortcut = conv2d_bn(out, filters, 1, 1, activation=None, padding='same')
        out = conv2d_bn(out, filters, 3, 3, activation='swish', padding='same')

        out = add([shortcut, out])
        out = Activation('swish')(out)
        out = BatchNormalization(axis=3)(out)

    return out

In [None]:
def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    """Conv2D -> BatchNormalization, followed by Swish when `activation == True`."""
    out = Conv2D(filters, size, strides=strides, padding=padding)(x)
    out = BatchNormalization()(out)
    # Equality (not truthiness) decides: only True, or a value equal to it, enables Swish.
    if not (activation == True):
        return out
    return Swish()(out)

def residual_block(blockInput, num_filters=16):
    """Residual unit: Swish/BN, two 3x3 conv blocks and a squeeze-excite gate, added to the batch-normalised input.

    The final addition joins the branch (with `num_filters` channels) and the
    normalised input, so their shapes have to match.
    """
    branch = Swish()(blockInput)
    branch = BatchNormalization()(branch)
    shortcut = BatchNormalization()(blockInput)

    # First conv block keeps its Swish, the second one is left linear.
    for use_activation in (True, False):
        branch = convolution_block(branch, num_filters, (3,3), activation=use_activation)

    branch = squeeze_excite_block()(branch)
    return Add()([branch, shortcut])

In [None]:
def up_and_concate(down_layer, layer, data_format='channels_last'):
    """Upsample `down_layer` by 2x2 and concatenate it with `layer` along axis 3.

    Args:
        down_layer: tensor to upsample.
        layer: skip tensor concatenated after the upsampled one.
        data_format: accepted for backward compatibility only. The function has
            always forced 'channels_last', so the value passed in is ignored.

    Returns:
        The concatenation `[upsampled down_layer, layer]`.
    """
    # Preserve the long-standing behaviour: the argument is overridden, which made
    # the former 'channels_first' branches (and an unused channel lookup) dead code.
    data_format = 'channels_last'

    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))
    return my_concat([up, layer])
def attention_up_and_concate(down_layer, layer, data_format='channels_last'):
    """Upsample `down_layer` by 2x2, gate `layer` with `attention_block_2d`, then concatenate both.

    The `data_format` argument is overwritten with 'channels_last' before use.
    """
    data_format = 'channels_last'
    channels_first = data_format == 'channels_first'

    channel_axis = 1 if channels_first else 3
    in_channel = down_layer.get_shape().as_list()[channel_axis]

    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)

    # The gate works with a quarter of the upsampled tensor's channels.
    gated = attention_block_2d(x=layer, g=up, inter_channel=in_channel // 4, data_format=data_format)

    if channels_first:
        my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=1))
    else:
        my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))

    return my_concat([up, gated])
def attention_block_2d(x, g, inter_channel, data_format='channels_last'):
    """Additive attention gate: scale `x` by a single-channel sigmoid map computed from `x` and `g`.

    Both inputs are projected to `inter_channel` channels with 1x1 convolutions,
    summed, passed through Swish, reduced to one channel and squashed with a
    sigmoid; `x` is multiplied by that map. The `data_format` argument is
    overwritten with 'channels_last' before use.
    """
    data_format = 'channels_last'

    def _pointwise(channels):
        # 1x1 convolution with unit stride.
        return Conv2D(channels, [1, 1], strides=[1, 1], data_format=data_format)

    theta_x = _pointwise(inter_channel)(x)
    phi_g = _pointwise(inter_channel)(g)

    f = Activation('swish')(add([theta_x, phi_g]))

    psi_f = _pointwise(1)(f)
    rate = Activation('sigmoid')(psi_f)

    return multiply([x, rate])
def res_block(input_layer, out_n_filters, batch_normalization=False, kernel_size=[3, 3], stride=[1, 1],
              padding='same', data_format='channels_first'):
    """Two stacked 1x1 -> kxk -> 1x1 convolution groups with a residual connection.

    The first two convolutions of each group use `out_n_filters // 4` filters.
    When the input channel count differs from `out_n_filters`, the skip path
    goes through a 1x1 convolution. The `data_format` argument is overwritten
    with 'channels_last' before use.
    """
    data_format = 'channels_last'
    channel_axis = 1 if data_format == 'channels_first' else 3
    input_n_filters = input_layer.get_shape().as_list()[channel_axis]

    shared = dict(strides=stride, padding=padding, data_format=data_format)
    narrow_filters = out_n_filters // 4

    layer = input_layer
    for _ in range(2):
        layer = Conv2D(narrow_filters, [1, 1], **shared)(layer)
        if batch_normalization:
            layer = BatchNormalization()(layer)
        layer = Activation('swish')(layer)
        layer = Conv2D(narrow_filters, kernel_size, **shared)(layer)
        layer = Conv2D(out_n_filters, [1, 1], **shared)(layer)

    skip_layer = input_layer
    if out_n_filters != input_n_filters:
        # Project the input so both operands of the addition have the same channel count.
        skip_layer = Conv2D(out_n_filters, [1, 1], **shared)(input_layer)

    return add([layer, skip_layer])


# Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net)
def rec_res_block(input_layer, out_n_filters, batch_normalization=False, kernel_size=[3, 3], stride=[1, 1],
                  padding='same', data_format='channels_first'):
    """Recurrent residual block: two recurrent stages wrapped in a residual connection.

    Each stage applies one convolution unit to its input and then two more units
    to the sum of the previous unit's output and the stage input. The
    `data_format` argument is overwritten with 'channels_last' before use.
    """
    data_format = 'channels_last'
    channel_axis = 1 if data_format == 'channels_first' else 3
    input_n_filters = input_layer.get_shape().as_list()[channel_axis]

    skip_layer = input_layer
    if out_n_filters != input_n_filters:
        skip_layer = Conv2D(out_n_filters, [1, 1], strides=stride, padding=padding, data_format=data_format)(
            input_layer)

    def _conv_unit(tensor):
        # Conv2D, optional BatchNormalization, ReLU.
        out = Conv2D(out_n_filters, kernel_size, strides=stride, padding=padding, data_format=data_format)(tensor)
        if batch_normalization:
            out = BatchNormalization()(out)
        return Activation('relu')(out)

    layer = skip_layer
    for _ in range(2):
        recurrent = _conv_unit(layer)
        for _ in range(2):
            recurrent = _conv_unit(add([recurrent, layer]))
        layer = recurrent

    return add([layer, skip_layer])

In [None]:
class encoder_layer(tf.keras.layers.Layer):
  """Encoder stage: two Conv-BN-Swish units, a squeeze-and-excitation gate and 2x2 max-pooling.

  Each convolution and each batch normalisation has its own layer instance. A
  single Conv2D cannot serve both positions because the first one sees the input
  channels while the second one sees `filters` channels. The squeeze-excite
  convolutions are created once here, so their weights are tracked by the layer.

  Args:
    filters: number of output channels of both convolutions.
    kernel_size, strides, padding, dilation_rate: forwarded to both convolutions.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filters, kernel_size=3, strides=(1,1), padding='same', dilation_rate=1, **kwargs):
    super(encoder_layer, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.dilation_rate = dilation_rate

    conv_args = dict(filters=self.filters, kernel_size=self.kernel_size, strides=self.strides,
                     padding=self.padding, dilation_rate=self.dilation_rate)
    self.conv1 = Conv2D(**conv_args)
    self.conv2 = Conv2D(**conv_args)
    self.bn1 = BatchNormalization()
    self.bn2 = BatchNormalization()
    self.swish = Swish()

    # Squeeze-and-excitation gate with the same settings as squeeze_excite_block's
    # defaults: reduce to a quarter of the channels (at least one), then expand back.
    se_args = dict(kernel_size=[1, 1], strides=[1, 1], kernel_initializer='he_normal',
                   padding='same', use_bias=True)
    self.se_reduce = Conv2D(max(1, int(self.filters * 0.25)), **se_args)
    self.se_expand = Conv2D(self.filters, **se_args)

    self.maxpool = MaxPooling2D(pool_size=(2, 2))

  def call(self, x, training=None):
    # training=None lets Keras decide the mode; a hard-coded True default would
    # keep batch normalisation on batch statistics at inference time.
    x = self.swish(self.bn1(self.conv1(x), training=training))
    x = self.swish(self.bn2(self.conv2(x), training=training))

    gate = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    gate = self.swish(self.se_reduce(gate))
    gate = tf.sigmoid(self.se_expand(gate))
    x = x * gate

    return self.maxpool(x)

## SegNet

In [None]:
def ASPP(x, filter):
    """Spatial pyramid block: one pooled branch plus four dilated-convolution branches, fused by a 1x1 convolution.

    The pooled branch averages over the full static height and width of `x`
    and upsamples back by the same factors, so both dimensions must be known.
    Every branch, and the fused output, ends with a squeeze-excite gate.
    """
    height, width = x.shape[1], x.shape[2]

    def _conv_bn_swish(tensor, kernel, rate):
        out = Conv2D(filter, kernel, dilation_rate=rate, padding="same", use_bias=False,
                     kernel_initializer='he_normal')(tensor)
        out = BatchNormalization()(out)
        return Swish()(out)

    # Image-level branch: pool to 1x1, project, resize back to the input size.
    pooled = AveragePooling2D(pool_size=(height, width))(x)
    pooled = _conv_bn_swish(pooled, 1, 1)
    pooled = UpSampling2D((height, width), interpolation='bilinear')(pooled)
    branches = [squeeze_excite_block()(pooled)]

    # (kernel size, dilation rate) of the four convolutional branches.
    for kernel, rate in ((1, 1), (3, 6), (5, 12), (7, 18)):
        branch = _conv_bn_swish(x, kernel, rate)
        branches.append(squeeze_excite_block()(branch))

    fused = Concatenate()(branches)
    fused = _conv_bn_swish(fused, 1, 1)
    return squeeze_excite_block()(fused)

In [None]:
#generator = seg_net(input_shape = (128,128,1), out_channels=3)
#generator.summary()

## Residual Attention

In [None]:
def residual_attention_concate_UNet(input_shape=(128,128,1), out_channels=3):
    """Build a U-Net with squeeze-excite encoder stages, an ASPP bottleneck and residual decoder stages.

    Skip connections are plain concatenations with transposed-convolution
    upsampling (the attention-gated variant is not used here). The network
    pools four times and upsamples four times, so the input height and width
    should be divisible by 16 for the concatenations to line up.

    Args:
        input_shape: shape of one input sample, without the batch axis.
        out_channels: number of channels of the final softmax output.

    Returns:
        A Keras `Model` mapping the input named 'data' to the softmax output.
    """

    def _encoder_stage(tensor, filters, name):
        # Two Conv-BN-Swish units, then a squeeze-excite gate whose Multiply layer is named `name`.
        for _ in range(2):
            tensor = Conv2D(filters, 3, padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Swish()(tensor)
        return squeeze_excite_block(name_block=name)(tensor)

    def _decoder_stage(tensor, skip, filters):
        # Upsample x2, concatenate the skip tensor, then dropout, conv and two residual blocks.
        up = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding="same")(tensor)
        merged = concatenate([up, skip])
        merged = Dropout(0.2)(merged)
        merged = Conv2D(filters, (3, 3), activation=None, padding="same")(merged)
        merged = residual_block(merged, filters)
        merged = residual_block(merged, filters)
        return LeakyReLU(alpha=0.1)(merged)

    inputs = Input(shape=input_shape, dtype='float', name='data')
    normalized = BatchNormalization()(inputs)

    # Encoder
    conv1 = _encoder_stage(normalized, 64, 'conv1')
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _encoder_stage(pool1, 128, 'conv2')
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _encoder_stage(pool2, 256, 'conv3')
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = _encoder_stage(pool3, 512, 'conv4')
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    pool4 = ASPP(pool4,1024)

    # Bottleneck
    conv5 = Conv2D(1024, (3, 3), activation=None, padding="same")(pool4)
    conv5 = residual_block(conv5, 1024)
    conv5 = residual_block(conv5, 1024)
    conv5 = LeakyReLU(alpha=0.1)(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Decoder. It starts from drop5: previously conv5 was used, which left the
    # bottleneck dropout disconnected from the model. The skip tensors are the
    # un-dropped encoder outputs, as before.
    conv6 = _decoder_stage(drop5, conv4, 512)
    conv7 = _decoder_stage(conv6, conv3, 256)
    conv8 = _decoder_stage(conv7, conv2, 128)
    conv9 = _decoder_stage(conv8, conv1, 64)

    output = Conv2D(out_channels, (1,1), padding="same", activation="softmax")(conv9)
    return Model(inputs, output)

#shape=(128,128,1)
#inputs=Input(shape)
#outputs1,model = residual_attention_concate_UNet(inputs,out_channels=1)

#skip=[]
#skip.append(model.get_layer('conv4').output)
#skip.append(model.get_layer('conv3').output)
#skip.append(model.get_layer('conv2').output)
#skip.append(model.get_layer('conv1').output)

In [None]:
# Build the residual U-Net for 128x128 single-channel inputs with 3 output channels and print its layer table.
generator = residual_attention_concate_UNet(input_shape=(128, 128, 1), out_channels=3)
generator.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
data (InputLayer)               [(None, 128, 128, 1) 0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 1)  4           data[0][0]                       
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 640         batch_normalization[0][0]        
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 64) 256         conv2d[0][0]                     
______________________________________________________________________________________________

## Mo-UNet (Swish)

In [None]:
def mo_unet(input_size = (128,128,1), out_channels=3):
    """Build an attention U-Net with Swish activations, squeeze-excite gates and an ASPP bottleneck.

    The encoder has four double-convolution stages (32, 64, 128, 256 filters);
    the decoder mirrors them and joins each skip tensor through
    `attention_up_and_concate`. The network pools four times and upsamples four
    times, so the input height and width should be divisible by 16.

    Args:
        input_size: shape of one input sample, without the batch axis.
        out_channels: number of channels of the final softmax output.

    Returns:
        A Keras `Model` mapping the input named 'data' to the softmax output.
    """

    def _double_conv(tensor, filters):
        # Two Conv-BN-Swish units followed by a squeeze-excite gate.
        for _ in range(2):
            tensor = Conv2D(filters, 3, padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Swish()(tensor)
        return squeeze_excite_block()(tensor)

    inputs = Input(shape=input_size, dtype='float', name='data')
    normalized = BatchNormalization()(inputs)

    # Encoder
    conv1 = _double_conv(normalized, 32)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _double_conv(pool1, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _double_conv(pool2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = _double_conv(pool3, 256)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    pool4 = ASPP(pool4,512)

    # Bottleneck
    conv5 = Conv2D(512, (3, 3), activation=None, padding="same")(pool4)
    conv5 = residual_block(conv5, 512)
    conv5 = residual_block(conv5, 512)
    conv5 = LeakyReLU(alpha=0.1)(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Decoder. It starts from drop5: previously conv5 was used, which left the
    # bottleneck dropout disconnected from the model.
    decoded = drop5
    for skip, filters in ((conv4, 256), (conv3, 128), (conv2, 64), (conv1, 32)):
        merged = attention_up_and_concate(decoded, skip)
        decoded = _double_conv(merged, filters)

    logits = Conv2D(out_channels, 3,  padding = 'same')(decoded)
    probabilities = Activation('softmax')(logits)

    return Model(inputs, probabilities)

In [None]:
#attention_unet = mo_unet(input_size = (128,128,1), out_channels=3).output

## Swin UNet

In [None]:
# Hyper-parameters of the Swin-UNet; they are passed to swin_unet_2d_base(...) in a later cell.
filter_num_begin = 32     # number of channels in the first downsampling block; it is also the number of embedded dimensions
depth = 4                  # the depth of SwinUNET; depth=4 means three down/upsampling levels and a bottom level 
stack_num_down = 2         # number of Swin Transformers per downsampling level
stack_num_up = 2           # number of Swin Transformers per upsampling level
patch_size = (4, 4)        # Extract 4-by-4 patches from the input image. Height and width of the patch must be equal.
num_heads = [4, 8, 8, 8]   # number of attention heads per down/upsampling level
window_size = [4, 2, 2, 2] # the size of attention window per down/upsampling level
num_mlp = 512              # number of MLP nodes within the Transformer
shift_window=True          # Apply window shifting, i.e., Swin-MSA

In [None]:
def swin_unet_2d_base(input_tensor, filter_num_begin, depth, stack_num_down, stack_num_up, 
                      patch_size, num_heads, window_size, num_mlp, shift_window=True, name='swin_unet'):
    '''
    The base of SwinUNET.

    Extracts and embeds patches from `input_tensor`, runs `depth` levels of Swin Transformer
    stacks with patch merging in between, then walks back up with patch expanding, concatenating
    the stored encoder output of each level and projecting it with a bias-free Dense layer.

    Arguments:
        input_tensor -- input Keras tensor; the first two entries of its shape after the batch
                        axis are divided by `patch_size` to get the patch grid
        filter_num_begin {int} -- embedding dimension of the first level; doubled at every downsampling
        depth {int} -- number of levels; `depth - 1` patch-merging steps are performed
        stack_num_down {int} -- `stack_num` of every encoder-side stack
        stack_num_up {int} -- `stack_num` of every decoder-side stack
        patch_size {tuple} -- patch height and width; only `patch_size[0]` is used as the final
                              upsampling rate, so both entries are expected to be equal
        num_heads {list} -- indexed per level; needs at least `depth` entries
        window_size {list} -- indexed per level; needs at least `depth` entries
        num_mlp {int} -- forwarded to every stack
        shift_window {bool} -- forwarded to every stack
        name {str} -- prefix used in the names of the concatenation and projection layers and
                      passed on to the stacks

    Returns:
        (X, X_skip) -- X is the output of the last patch-expanding layer; X_skip is the list of
                       stored encoder outputs in reversed order (deepest level first)

    Side effects:
        prints `X.shape` and `X_skip` before returning
    '''
    # Compute the number of patches to be embedded along each spatial axis
    input_size = input_tensor.shape.as_list()[1:]
    num_patch_x = input_size[0]//patch_size[0]
    num_patch_y = input_size[1]//patch_size[1]
    
    # Number of Embedded dimensions
    embed_dim = filter_num_begin
    
    depth_ = depth
    
    X_skip = []

    X = input_tensor
    
    # Patch extraction
    X = transformer_layers.patch_extract(patch_size)(X)

    # Embed patches to tokens
    X = transformer_layers.patch_embedding(num_patch_x*num_patch_y, embed_dim)(X)
    
    # The first Swin Transformer stack
    X = swin_transformer_stack(X, stack_num=stack_num_down, 
                               embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                               num_heads=num_heads[0], window_size=window_size[0], num_mlp=num_mlp, 
                               shift_window=shift_window, name='{}_swin_down0'.format(name))
    X_skip.append(X)
    
    # Downsampling blocks
    for i in range(depth_-1):
        
        # Patch merging
        X = transformer_layers.patch_merging((num_patch_x, num_patch_y), embed_dim=embed_dim, name='down{}'.format(i))(X)
        
        # update token shape info: twice the channels, half the patches per axis
        embed_dim = embed_dim*2
        num_patch_x = num_patch_x//2
        num_patch_y = num_patch_y//2
        
        # Swin Transformer stacks
        X = swin_transformer_stack(X, stack_num=stack_num_down, 
                                   embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                                   num_heads=num_heads[i+1], window_size=window_size[i+1], num_mlp=num_mlp, 
                                   shift_window=shift_window, name='{}_swin_down{}'.format(name, i+1))
        
        # Store tensors for concat
        X_skip.append(X)
        
    # reverse indexing encoded tensors and hyperparams
    X_skip = X_skip[::-1]
    num_heads = num_heads[::-1]
    window_size = window_size[::-1]
    
    # upsampling begins at the deepest available tensor
    X = X_skip[0]
    
    # other tensors are preserved for concatenation
    X_decode = X_skip[1:]
    
    depth_decode = len(X_decode)
    
    for i in range(depth_decode):
        
        # Patch expanding
        X = transformer_layers.patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                                               embed_dim=embed_dim, upsample_rate=2, return_vector=True)(X)
        

        # update token shape info: half the channels, twice the patches per axis
        embed_dim = embed_dim//2
        num_patch_x = num_patch_x*2
        num_patch_y = num_patch_y*2
        
        # Concatenation and linear projection
        X = concatenate([X, X_decode[i]], axis=-1, name='{}_concat_{}'.format(name, i))
        X = Dense(embed_dim, use_bias=False, name='{}_concat_linear_proj_{}'.format(name, i))(X)
        
        # Swin Transformer stacks
        # NOTE(review): index i of the reversed lists is used here, i.e. decoder stage i takes the
        # hyper-parameters of the level one below the one it has just upsampled to -- confirm this is intended.
        X = swin_transformer_stack(X, stack_num=stack_num_up, 
                           embed_dim=embed_dim, num_patch=(num_patch_x, num_patch_y), 
                           num_heads=num_heads[i], window_size=window_size[i], num_mlp=num_mlp, 
                           shift_window=shift_window, name='{}_swin_up{}'.format(name, i))
        
    # The last expanding layer; it produces full-size feature maps based on the patch size
    # !!! <--- "patch_size[0]" is used; it assumes patch_size = (size, size)
    X = transformer_layers.patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                                               embed_dim=embed_dim, upsample_rate=patch_size[0], return_vector=False)(X)
    # Debug output: shape of the final feature map and the list of skip tensors
    print(X.shape)
    print(X_skip)
    
    return X, X_skip

In [None]:
# Build the Swin-UNet base on a 128x128 single-channel input using the hyper-parameters set in the earlier cell.
input_size = (128, 128, 1)
IN = Input(input_size)

# X is the full-size feature map, X_skip the stored encoder outputs (deepest first).
X, X_skip = swin_unet_2d_base(IN, filter_num_begin, depth, stack_num_down, stack_num_up, 
                  patch_size, num_heads, window_size, num_mlp, shift_window=shift_window, name='swin_unet')
# The string literal below holds a disabled output head (1x1 sigmoid convolution + Model).
# It is not executed; as the last expression of the cell it is only echoed as the cell output.
'''
n_labels = 1
OUT = Conv2D(n_labels, kernel_size=1, use_bias=False, activation='sigmoid')(X)

# Model configuration
swin_unet = Model(inputs=[IN,], outputs=[OUT,])
#generator.summary()'''

(None, 128, 128, 8)
[<KerasTensor: shape=(None, 16, 256) dtype=float32 (created by layer 'swin_transformer_block_7')>, <KerasTensor: shape=(None, 64, 128) dtype=float32 (created by layer 'swin_transformer_block_5')>, <KerasTensor: shape=(None, 256, 64) dtype=float32 (created by layer 'swin_transformer_block_3')>, <KerasTensor: shape=(None, 1024, 32) dtype=float32 (created by layer 'swin_transformer_block_1')>]


"\nn_labels = 1\nOUT = Conv2D(n_labels, kernel_size=1, use_bias=False, activation='sigmoid')(X)\n\n# Model configuration\nswin_unet = Model(inputs=[IN,], outputs=[OUT,])\n#generator.summary()"

## EfficientUnet

In [None]:
# Network-wide EfficientNet settings; the defaults assignment below makes every field default to None.
GlobalParams = namedtuple('GlobalParams', ['batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
                                           'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
                                           'drop_connect_rate'])
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)

# Settings of one block group; every field likewise defaults to None.
BlockArgs = namedtuple('BlockArgs', ['kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'expand_ratio',
                                     'id_skip', 'strides', 'se_ratio'])
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)

# Weight files keyed by model name: local file name ('name'), download location ('url')
# and the file's MD5 checksum ('md5'). Only efficientnet-b0 ... efficientnet-b5 have an entry.
IMAGENET_WEIGHTS = {

    'efficientnet-b0': {
        'name': 'efficientnet-b0_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b0_imagenet_1000.h5',
        'md5': 'bca04d16b1b8a7c607b1152fe9261af7',
    },

    'efficientnet-b1': {
        'name': 'efficientnet-b1_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b1_imagenet_1000.h5',
        'md5': 'bd4a2b82f6f6bada74fc754553c464fc',
    },

    'efficientnet-b2': {
        'name': 'efficientnet-b2_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b2_imagenet_1000.h5',
        'md5': '45b28b26f15958bac270ab527a376999',
    },

    'efficientnet-b3': {
        'name': 'efficientnet-b3_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b3_imagenet_1000.h5',
        'md5': 'decd2c8a23971734f9d3f6b4053bf424',
    },

    'efficientnet-b4': {
        'name': 'efficientnet-b4_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b4_imagenet_1000.h5',
        'md5': '01df77157a86609530aeb4f1f9527949',
    },

    'efficientnet-b5': {
        'name': 'efficientnet-b5_imagenet_1000.h5',
        'url': 'https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b5_imagenet_1000.h5',
        'md5': 'c31311a1a38b5111e14457145fccdf32',
    }

}


def round_filters(filters, global_params):
    """Scale a filter count by the width coefficient and snap it to the depth divisor.

    Args:
        filters: Base number of filters.
        global_params: Object exposing ``width_coefficient``, ``depth_divisor``
            and ``min_depth``.

    Returns:
        ``filters`` unchanged when the width coefficient is falsy; otherwise the
        scaled count rounded to a multiple of ``depth_divisor`` (never below
        ``min_depth``, or the divisor itself when ``min_depth`` is falsy), as an int.
    """
    width_mult, step, floor = (
        global_params.width_coefficient,
        global_params.depth_divisor,
        global_params.min_depth,
    )
    if not width_mult:
        return filters

    scaled = filters * width_mult
    lower_bound = floor or step
    # Round to the nearest multiple of `step`, then clamp from below.
    snapped = int(scaled + step / 2) // step * step
    rounded = max(lower_bound, snapped)
    # If rounding removed more than 10% of the scaled width, go one step up.
    if rounded < 0.9 * scaled:
        rounded += step
    return int(rounded)


def round_repeats(repeats, global_params):
    """Scale a block repeat count by the depth coefficient, rounding up.

    Returns ``repeats`` untouched when ``global_params.depth_coefficient`` is falsy.
    """
    depth_mult = global_params.depth_coefficient
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats


def get_efficientnet_params(model_name, override_params=None):
    """Look up the block layout and global hyper-parameters for a named EfficientNet.

    Args:
        model_name: One of 'efficientnet-b0' .. 'efficientnet-b7'.
        override_params: Optional dict of GlobalParams field names to replacement values.

    Returns:
        A ``(blocks_args, global_params)`` pair: the decoded list of block
        arguments and the GlobalParams tuple (with overrides applied).

    Raises:
        KeyError: If ``model_name`` is not a known variant.
    """
    # model name -> (width_coefficient, depth_coefficient, resolution, dropout_rate)
    # The resolution is kept for reference only; it is never read.
    scaling_table = {
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.3),
        'efficientnet-b5': (1.6, 2.2, 456, 0.3),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    }
    if model_name not in scaling_table:
        raise KeyError('There is no model named {}.'.format(model_name))

    width_coefficient, depth_coefficient, _resolution, dropout_rate = scaling_table[model_name]

    # The seven stages shared by every variant, in block-string notation.
    stage_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]

    base_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=0.2,
        num_classes=1000,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None)
    global_params = base_params._replace(**override_params) if override_params else base_params

    return BlockDecoder().decode(stage_strings), global_params


class BlockDecoder(object):
    """Converts between compact block strings and ``BlockArgs`` tuples.

    A block string is a '_'-separated list of tokens, each a letter key followed
    by a numeric value, e.g. ``'r1_k3_s11_e1_i32_o16_se0.25'``. The optional
    literal token ``noskip`` disables the identity skip connection.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Parse one block string into a ``BlockArgs``.

        Raises:
            ValueError: If the stride token is missing or is not exactly two digits.
            KeyError: If one of the k / r / i / o / e tokens is missing.
        """
        assert isinstance(block_string, str)
        ops = block_string.split('_')
        options = {}
        for op in ops:
            # Split each token at its first digit: 'se0.25' -> ['se', '0.25', ''].
            # Tokens without a digit (e.g. 'noskip') yield a single piece and are skipped.
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        if 's' not in options or len(options['s']) != 2:
            raise ValueError('Strides options should be a pair of integers.')

        return BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            strides=[int(options['s'][0]), int(options['s'][1])]
        )

    @staticmethod
    def _encode_block_string(block):
        """Serialise one block (any object with the ``BlockArgs`` fields) to a string.

        The ``se`` token is emitted only when ``se_ratio`` is set and lies in (0, 1];
        ``noskip`` is appended only when ``id_skip`` is exactly ``False``.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (block.strides[0], block.strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        se_ratio = block.se_ratio
        # se_ratio is None for blocks decoded without an 'se' token; comparing None
        # with a number raises TypeError, so test for None before the range check.
        if se_ratio is not None and 0 < se_ratio <= 1:
            args.append('se%s' % se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    def decode(self, string_list):
        """Decodes a list of string notations to specify blocks inside the network.
        Args:
          string_list: a list of strings, each string is a notation of block.
        Returns:
          A list of namedtuples to represent blocks arguments.
        """
        assert isinstance(string_list, list)
        return [self._decode_block_string(block_string) for block_string in string_list]

    def encode(self, blocks_args):
        """Encodes a list of Blocks to a list of strings.
        Args:
          blocks_args: A list of namedtuples to represent blocks arguments.
        Returns:
          a list of strings, each string is a notation of block.
        """
        return [self._encode_block_string(block) for block in blocks_args]


class Swish(layers.Layer):
    """Keras layer that applies ``tf.nn.swish`` to its input."""

    def __init__(self, name=None, **kwargs):
        super().__init__(name=name, **kwargs)

    def call(self, inputs, **kwargs):
        activated = tf.nn.swish(inputs)
        return activated

    def get_config(self):
        """Return the base layer config with the layer name set explicitly."""
        base_config = super().get_config()
        base_config.update(name=self.name)
        return base_config


def SEBlock(block_args, **kwargs):
    """Create a squeeze-and-excitation sub-graph builder for one MBConv block.

    Args:
        block_args: Block configuration; ``input_filters``, ``se_ratio`` and
            ``expand_ratio`` are read.
        **kwargs: ``block_name`` (optional) is used as a prefix for layer names.

    Returns:
        A function mapping a 4-D tensor to the same tensor rescaled per channel
        by the learned gate.
    """
    squeeze_channels = max(1, int(block_args.input_filters * block_args.se_ratio))
    excite_channels = block_args.input_filters * block_args.expand_ratio
    pooled_axes = [1, 2]
    prefix = kwargs.get('block_name', '')

    def block(inputs):
        # Squeeze: average over axes 1 and 2, keeping them as size-1 dimensions.
        gate = layers.Lambda(lambda a: K.mean(a, axis=pooled_axes, keepdims=True))(inputs)

        # Reduce -> swish -> expand, all with 1x1 convolutions.
        reduce_conv = layers.Conv2D(
            squeeze_channels,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            name=prefix + 'se_reduce_conv2d',
            use_bias=True
        )
        gate = reduce_conv(gate)
        gate = Swish(name=prefix + 'se_swish')(gate)

        expand_conv = layers.Conv2D(
            excite_channels,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            name=prefix + 'se_expand_conv2d',
            use_bias=True
        )
        gate = expand_conv(gate)

        # Excite: sigmoid gate multiplied back onto the block input.
        gate = layers.Activation('sigmoid')(gate)
        return layers.Multiply()([gate, inputs])

    return block


class DropConnect(layers.Layer):
    """Randomly zeroes whole samples of the input during training.

    In the training phase each sample along axis 0 is kept with probability
    ``1 - drop_connect_rate`` and the kept samples are divided by that
    probability; otherwise the input is returned unchanged.
    """

    def __init__(self, drop_connect_rate, **kwargs):
        super().__init__(**kwargs)
        self.drop_connect_rate = drop_connect_rate

    def call(self, inputs, **kwargs):
        survival_prob = 1.0 - self.drop_connect_rate

        # One uniform draw per sample, broadcast over the remaining three axes.
        n_samples = tf.shape(inputs)[0]
        noise = survival_prob + tf.random.uniform([n_samples, 1, 1, 1], dtype=inputs.dtype)
        # floor() turns the shifted draw into a 0/1 keep mask.
        keep_mask = tf.floor(noise)
        dropped = tf.math.divide(inputs, survival_prob) * keep_mask

        return K.in_train_phase(dropped, inputs, training=None)

    def get_config(self):
        """Return the base layer config extended with ``drop_connect_rate``."""
        base_config = super().get_config()
        base_config.update(drop_connect_rate=self.drop_connect_rate)
        return base_config


def conv_kernel_initializer(shape, dtype=K.floatx()):
    """Draw a conv kernel from a zero-mean normal with variance ``2 / fan_out``.

    ``shape`` must have four entries; fan-out is the product of the first,
    second and fourth (the third is ignored).
    """
    height, width, _unused, n_out = shape
    fan_out = int(height * width * n_out)
    stddev = np.sqrt(2.0 / fan_out)
    return tf.random.normal(shape, mean=0.0, stddev=stddev, dtype=dtype)


def dense_kernel_initializer(shape, dtype=K.floatx()):
    """Draw a dense kernel uniformly from ``[-b, b)`` with ``b = 1 / sqrt(shape[1])``."""
    bound = 1.0 / np.sqrt(shape[1])
    return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)


def MBConvBlock(block_args, global_params, idx, drop_connect_rate=None):
    """Create a builder for one mobile inverted-bottleneck (MBConv) block.

    Args:
        block_args: Block configuration (filters, kernel size, strides, expand
            ratio, squeeze-excite ratio, skip flag).
        global_params: Supplies ``batch_norm_momentum`` and ``batch_norm_epsilon``.
        idx: Block index; every layer in the block is named ``blocks_<idx>_...``.
        drop_connect_rate: If truthy, a DropConnect layer with this rate is
            applied to the block output before the residual addition.

    Returns:
        A function that maps an input tensor to the block output tensor.
    """
    # Number of channels after the expansion phase.
    filters = block_args.input_filters * block_args.expand_ratio
    batch_norm_momentum = global_params.batch_norm_momentum
    batch_norm_epsilon = global_params.batch_norm_epsilon
    # Squeeze-excite is enabled only for a ratio in (0, 1].
    has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)

    # Prefix shared by all layer names in this block; other code looks layers up by these names.
    block_name = 'blocks_' + str(idx) + '_'

    def block(inputs):
        x = inputs

        # Expansion phase
        # Skipped when expand_ratio is 1, since the channel count would not change.
        if block_args.expand_ratio != 1:
            expand_conv = layers.Conv2D(filters,
                                        kernel_size=[1, 1],
                                        strides=[1, 1],
                                        kernel_initializer=conv_kernel_initializer,
                                        padding='same',
                                        use_bias=False,
                                        name=block_name + 'expansion_conv2d'
                                        )(x)
            bn0 = layers.BatchNormalization(momentum=batch_norm_momentum,
                                            epsilon=batch_norm_epsilon,
                                            name=block_name + 'expansion_batch_norm')(expand_conv)

            x = Swish(name=block_name + 'expansion_swish')(bn0)

        # Depth-wise convolution phase
        # This is the only layer in the block that uses block_args.strides.
        kernel_size = block_args.kernel_size
        depthwise_conv = layers.DepthwiseConv2D(
            [kernel_size, kernel_size],
            strides=block_args.strides,
            depthwise_initializer=conv_kernel_initializer,
            padding='same',
            use_bias=False,
            name=block_name + 'depthwise_conv2d'
        )(x)
        bn1 = layers.BatchNormalization(momentum=batch_norm_momentum,
                                        epsilon=batch_norm_epsilon,
                                        name=block_name + 'depthwise_batch_norm'
                                        )(depthwise_conv)
        x = Swish(name=block_name + 'depthwise_swish')(bn1)

        if has_se:
            x = SEBlock(block_args, block_name=block_name)(x)

        # Output phase
        # 1x1 projection to output_filters followed by batch norm; no activation afterwards.
        project_conv = layers.Conv2D(
            block_args.output_filters,
            kernel_size=[1, 1],
            strides=[1, 1],
            kernel_initializer=conv_kernel_initializer,
            padding='same',
            name=block_name + 'output_conv2d',
            use_bias=False)(x)
        x = layers.BatchNormalization(momentum=batch_norm_momentum,
                                      epsilon=batch_norm_epsilon,
                                      name=block_name + 'output_batch_norm'
                                      )(project_conv)
        # The residual connection is added only when all strides are 1 and the
        # input and output filter counts match, i.e. when the shapes agree.
        if block_args.id_skip:
            if all(
                    s == 1 for s in block_args.strides
            ) and block_args.input_filters == block_args.output_filters:
                # only apply drop_connect if skip presents.
                if drop_connect_rate:
                    x = DropConnect(drop_connect_rate)(x)
                x = layers.add([x, inputs])

        return x

    return block


def freeze_efficientunet_first_n_blocks(model, n):
    """Freeze every layer up to and including the n-th MBConv block of ``model``.

    Blocks are identified by their ``blocks_<i>_output_batch_norm`` layers.

    Args:
        model: Object exposing ``get_layer(name)`` (raising ValueError for an
            unknown name) and a ``layers`` list whose items have ``name`` and
            ``trainable`` attributes.
        n: Number of leading blocks to freeze. If ``n <= 0`` a message is
            printed and nothing is changed.

    Raises:
        ValueError: If ``n`` exceeds the number of blocks found, including the
            case where the model contains no MBConv block at all.
    """
    # Nothing to do: bail out before scanning the model.
    if n <= 0:
        print('n is less than or equal to 0, therefore no layer will be frozen.')
        return

    # Count consecutively numbered blocks by probing layer names until one is missing.
    named_blocks = 0
    while True:
        try:
            model.get_layer('blocks_{}_output_batch_norm'.format(named_blocks))
        except ValueError:
            break
        named_blocks += 1

    block_names = ['blocks_{}_output_batch_norm'.format(i) for i in range(named_blocks)]

    # Record the position in model.layers of each block's final layer, in block order.
    # A cursor is used rather than popping from the list, so an empty list is safe.
    block_layer_indices = []
    cursor = 0
    for layer_idx, layer in enumerate(model.layers):
        if cursor == len(block_names):
            break
        if layer.name == block_names[cursor]:
            block_layer_indices.append(layer_idx)
            cursor += 1
    n_blocks = len(block_layer_indices)

    if n > n_blocks:
        raise ValueError("There are {} blocks in total, n cannot be greater than {}.".format(n_blocks, n_blocks))

    last_frozen_idx = block_layer_indices[n - 1]
    for layer in model.layers[:last_frozen_idx + 1]:
        layer.trainable = False


def unfreeze_efficientunet(model):
    """Mark every layer in ``model.layers`` as trainable."""
    every_layer = model.layers
    for position in range(len(every_layer)):
        every_layer[position].trainable = True


In [None]:
# Names exported by a star-import if this code is used as a module: the generic
# model factory plus one encoder factory per EfficientNet variant (b0 .. b7).
__all__ = ['get_model_by_name', 'get_efficientnet_b0_encoder', 'get_efficientnet_b1_encoder',
           'get_efficientnet_b2_encoder', 'get_efficientnet_b3_encoder', 'get_efficientnet_b4_encoder',
           'get_efficientnet_b5_encoder', 'get_efficientnet_b6_encoder', 'get_efficientnet_b7_encoder']


def efficientnet(input_shape, blocks_args_list, global_params):
    """Assemble an EfficientNet classifier with the Keras functional API.

    Layout: stem conv -> stack of MBConv blocks -> 1x1 head conv -> global
    average pooling -> optional dropout -> softmax dense layer.

    Args:
        input_shape: Shape given to ``layers.Input`` (batch axis excluded).
        blocks_args_list: Sequence of ``BlockArgs``, one per stage, holding the
            unscaled filter and repeat counts. It is iterated more than once.
        global_params: ``GlobalParams`` with the batch-norm settings, width and
            depth coefficients, dropout / drop-connect rates and ``num_classes``.

    Returns:
        A ``models.Model`` from the input tensor to the softmax output.
    """
    batch_norm_momentum = global_params.batch_norm_momentum
    batch_norm_epsilon = global_params.batch_norm_epsilon

    # Stem part
    model_input = layers.Input(shape=input_shape)
    x = layers.Conv2D(
        filters=round_filters(32, global_params),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        use_bias=False,
        name='stem_conv2d'
    )(model_input)

    x = layers.BatchNormalization(
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon,
        name='stem_batch_norm'
    )(x)

    x = Swish(name='stem_swish')(x)

    # Blocks part
    idx = 0
    # A missing (None) drop-connect rate means "no drop connect".
    drop_rate = global_params.drop_connect_rate or 0.0
    # Count the blocks that are actually built, i.e. after depth scaling. Using the
    # unscaled repeat counts would let the per-block rate grow past `drop_rate`
    # whenever the depth coefficient is greater than 1.
    n_blocks = sum(round_repeats(blocks_args.num_repeat, global_params) for blocks_args in blocks_args_list)
    # The rate grows linearly with the block index: block i uses drop_rate * i / n_blocks.
    drop_rate_dx = drop_rate / n_blocks if n_blocks else 0.0

    for blocks_args in blocks_args_list:
        assert blocks_args.num_repeat > 0
        # Update block input and output filters based on depth multiplier.
        blocks_args = blocks_args._replace(
            input_filters=round_filters(blocks_args.input_filters, global_params),
            output_filters=round_filters(blocks_args.output_filters, global_params),
            num_repeat=round_repeats(blocks_args.num_repeat, global_params)
        )

        # The first block needs to take care of stride and filter size increase.
        x = MBConvBlock(blocks_args, global_params, idx, drop_connect_rate=drop_rate_dx * idx)(x)
        idx += 1

        # The remaining repeats of the stage keep the resolution and channel count.
        if blocks_args.num_repeat > 1:
            blocks_args = blocks_args._replace(input_filters=blocks_args.output_filters, strides=[1, 1])

        for _ in range(blocks_args.num_repeat - 1):
            x = MBConvBlock(blocks_args, global_params, idx, drop_connect_rate=drop_rate_dx * idx)(x)
            idx += 1

    # Head part
    x = layers.Conv2D(
        filters=round_filters(1280, global_params),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        use_bias=False,
        name='head_conv2d'
    )(x)

    x = layers.BatchNormalization(
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon,
        name='head_batch_norm'
    )(x)

    x = Swish(name='head_swish')(x)

    x = layers.GlobalAveragePooling2D(name='global_average_pooling2d')(x)

    # A dropout rate of None or 0 disables the dropout layer.
    if global_params.dropout_rate and global_params.dropout_rate > 0:
        x = layers.Dropout(global_params.dropout_rate)(x)

    x = layers.Dense(
        global_params.num_classes,
        kernel_initializer=dense_kernel_initializer,
        activation='softmax',
        name='head_dense'
    )(x)

    model = models.Model(model_input, x)

    return model


def get_model_by_name(model_name, input_shape, classes=1000, pretrained=False):
    """Get an EfficientNet model by its name.

    Args:
        model_name: Variant name, e.g. 'efficientnet-b0'.
        input_shape: Input shape handed to ``efficientnet``.
        classes: Size of the final dense layer.
        pretrained: If True, download and load the weights listed in
            ``IMAGENET_WEIGHTS`` for this model.

    Returns:
        The Keras model. If ``pretrained`` is True but ``IMAGENET_WEIGHTS`` has
        no entry for ``model_name``, a note is printed and the model keeps its
        random initialisation.
    """
    blocks_args, global_params = get_efficientnet_params(model_name, override_params={'num_classes': classes})
    model = efficientnet(input_shape, blocks_args, global_params)

    if not pretrained:
        return model

    # Only the table lookup may signal "no pretrained weights". Errors raised while
    # downloading or loading the file propagate instead of being reported as a missing entry.
    weights = IMAGENET_WEIGHTS.get(model_name)
    if weights is None:
        print("NOTE: Currently model {} doesn't have pretrained weights, therefore a model with randomly initialized"
              " weights is returned.".format(repr(model_name)))
        return model

    weights_path = get_file(
        weights['name'],
        weights['url'],
        cache_subdir='models',
        md5_hash=weights['md5'],
    )
    model.load_weights(weights_path)

    return model


def get_efficientnet_encoder(model_name, input_shape, pretrained=False):
    """Build an EfficientNet and return it truncated to its convolutional part.

    Args:
        model_name: Variant name, e.g. 'efficientnet-b0'.
        input_shape: Input shape handed to ``get_model_by_name``.
        pretrained: Forwarded to ``get_model_by_name``.

    Returns:
        A ``models.Model`` from the network input to the output of the
        ``head_swish`` layer, i.e. without global pooling, dropout or dense head.
    """
    model = get_model_by_name(model_name, input_shape, pretrained=pretrained)
    # End the encoder at 'head_swish' directly. Popping from `Model.layers` does not
    # remove a layer in tf.keras (the property returns a fresh list), so building the
    # model up to the pooling layer and popping afterwards left the pooling in place.
    return models.Model(model.input, model.get_layer('head_swish').output)


def get_efficientnet_b0_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B0 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b0', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b1_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B1 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b1', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b2_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B2 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b2', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b3_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B3 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b3', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b4_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B4 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b4', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b5_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B5 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b5', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b6_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B6 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b6', input_shape, pretrained=pretrained)
    return encoder


def get_efficientnet_b7_encoder(input_shape, pretrained=False):
    """Return the EfficientNet-B7 encoder for ``input_shape``."""
    encoder = get_efficientnet_encoder('efficientnet-b7', input_shape, pretrained=pretrained)
    return encoder


In [None]:
def BottleNeck1():
  """Return a builder for a residual bottleneck made of two 1x1 conv stages.

  Each stage is Conv2D (1x1, no bias) -> BatchNormalization -> LeakyReLU(0.2) with
  as many filters as the input has channels. The result is added to the input
  and batch-normalised once more.
  """
  def call(inputs):
    channels = inputs.shape[-1]
    branch = inputs
    for _ in range(2):
      branch = Conv2D(channels, kernel_size=1, padding='same', kernel_initializer='he_normal', use_bias=False)(branch)
      branch = BatchNormalization()(branch)
      branch = LeakyReLU(0.2)(branch)
    merged = branch + inputs

    return BatchNormalization()(merged)
  return call

In [None]:
from tensorflow.keras.layers import *
from tensorflow.keras import models



# Names exported by a star-import if this code is used as a module.
# NOTE(review): 'get_efficient_unet_b6' and 'get_efficient_unet_b7' are listed here, but no
# definition with those names is visible near the b0 .. b5 builders; confirm they exist.
__all__ = ['get_efficient_unet_b0', 'get_efficient_unet_b1', 'get_efficient_unet_b2', 'get_efficient_unet_b3',
           'get_efficient_unet_b4', 'get_efficient_unet_b5', 'get_efficient_unet_b6', 'get_efficient_unet_b7',
           'get_blocknr_of_skip_candidates']


def get_blocknr_of_skip_candidates(encoder, verbose=False):
    """
    Find the MBConv blocks whose outputs can serve as U-Net skip connections.

    Blocks are visited in order of their index; a block becomes a candidate when
    its output is the first one seen with that (height, width).
    :param encoder: the encoder; must expose get_layer(name) raising ValueError for unknown names
    :param verbose: if set to True, the shape information of all blocks will be printed in the console
    :return: a list of block numbers
    """
    seen_shapes = set()
    candidates = []
    block_nr = 0
    # The scan ends when get_layer raises ValueError for the first missing block index.
    try:
        while True:
            block_output = encoder.get_layer('blocks_{}_output_batch_norm'.format(block_nr)).output
            spatial = (int(block_output.shape[1]), int(block_output.shape[2]))
            if spatial not in seen_shapes:
                seen_shapes.add(spatial)
                candidates.append(block_nr)
            if verbose:
                print('blocks_{}_output_shape: {}'.format(block_nr, spatial))
            block_nr += 1
    except ValueError:
        pass
    return candidates


In [None]:
def ResidualBlock():
  """Return a builder for a residual block of factorised 3x3 convolutions.

  The residual branch applies (3,1) then (1,3) convolutions twice, each followed
  by BatchNormalization and LeakyReLU(0.2), keeping the input channel count.
  The shortcut is the batch-normalised input, and the two are added.
  """
  def call(inputs):
    channels = inputs.shape[-1]
    branch = inputs
    # Two (3x1, 1x3) pairs, each emulating a 3x3 convolution.
    for kernel in ((3, 1), (1, 3), (3, 1), (1, 3)):
      branch = Conv2D(channels, kernel_size=kernel, padding='same')(branch)
      branch = BatchNormalization()(branch)
      branch = LeakyReLU(0.2)(branch)
    shortcut = BatchNormalization()(inputs)
    #branch = Dropout(0.2)(branch)
    return shortcut + branch
  return call
# Block that halves the spatial size of a skip connection.
def dowsample_skip():
  """Return a builder: 3x3 Conv2D (no bias) -> BatchNormalization -> ReLU -> 2x2 max pooling.

  The convolution keeps the channel count of its input.
  """
  def call(inputs):
    channels = inputs.shape[-1]
    reduced = Conv2D(channels, kernel_size=3, strides=1, padding='same', kernel_initializer='he_normal', use_bias=False)(inputs)
    reduced = BatchNormalization()(reduced)
    reduced = Activation('relu')(reduced)
    return MaxPooling2D(pool_size=(2, 2), strides=2)(reduced)
  return call

# Decoder building block.
def Conv2DTranspose_block2(filters, transpose_kernel_size=(2, 2), upsample_rate=(2, 2),interpolation='bilinear', skip=None):
  """Return a decoder stage: transposed conv, optional skip concat, then ResidualBlock.

  Args:
    filters: Number of filters of the transposed convolution.
    transpose_kernel_size: Kernel size of the transposed convolution.
    upsample_rate: Strides of the transposed convolution.
    interpolation: Accepted but not used by this block.
    skip: Optional tensor concatenated with the upsampled tensor.
  """
  def layer(input_tensor):
    upsampled = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate, padding='same', kernel_initializer='he_normal')(input_tensor)
    merged = upsampled if skip is None else Concatenate()([upsampled, skip])
    #merged = Dropout(0.2)(merged)
    return ResidualBlock()(merged)
  return layer


In [None]:
def get_efficient_unet_vs1(encoder, out_channels=1, block_type='upsampling', concat_input=True):
    """Attach a five-stage transposed-convolution decoder to an EfficientNet encoder.

    Args:
        encoder: Model containing 'blocks_<i>_output_batch_norm' and 'head_swish' layers.
        out_channels: Number of channels of the final sigmoid convolution.
        block_type: Accepted but not used; every stage uses Conv2DTranspose_block2.
        concat_input: If True, the last decoder stage is concatenated with the encoder input.

    Returns:
        A tuple ``(output_tensor, skip_tensors, model)`` where ``skip_tensors`` is the
        list of encoder block outputs used as skips and ``model`` maps the encoder
        input to ``output_tensor``.
    """
    # Collect the skip connections from the encoder.
    skip_features = [
        encoder.get_layer('blocks_{}_output_batch_norm'.format(block_nr)).output
        for block_nr in get_blocknr_of_skip_candidates(encoder)
    ]

    # The last candidate is dropped since it is not used for concatenation.
    skip_features.pop()

    # Shallowest first: network input, the skips, then the encoder head. Consumed from the end.
    pyramid = [encoder.input] + skip_features + [encoder.get_layer('head_swish').output]
    up_block = Conv2DTranspose_block2

    # Decoder: bottleneck on the head, then one upsampling stage per remaining skip.
    decoded = BottleNeck1()(pyramid.pop())
    for width in (512, 256, 128, 64):
        decoded = up_block(width, skip=pyramid.pop())(decoded)

    if concat_input:
        decoded = up_block(32, skip=pyramid.pop())(decoded)
    else:
        decoded = up_block(32)(decoded)

    decoded = Conv2D(3, (1, 1), padding='same', kernel_initializer=conv_kernel_initializer, use_bias=False)(decoded)
    decoded = BatchNormalization()(decoded)
    decoded = LeakyReLU(0.2)(decoded)
    decoded = Conv2D(out_channels, (1, 1), padding='same', activation='sigmoid')(decoded)

    return decoded, skip_features, models.Model(encoder.input, decoded)


In [None]:
def _build_efficient_unet(encoder_factory, input_shape, out_channels, pretrained, block_type, concat_input):
    """Build the encoder with ``encoder_factory`` and return only the assembled U-Net model."""
    encoder = encoder_factory(input_shape, pretrained=pretrained)
    # get_efficient_unet_vs1 returns (output tensor, skip tensors, model); callers of the
    # get_efficient_unet_bN functions are promised the model alone.
    _, _, model = get_efficient_unet_vs1(encoder, out_channels, block_type=block_type, concat_input=concat_input)
    return model


def get_efficient_unet_b0(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B0 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B0 model
    """
    return _build_efficient_unet(get_efficientnet_b0_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)


def get_efficient_unet_b1(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B1 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B1 model
    """
    return _build_efficient_unet(get_efficientnet_b1_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)


def get_efficient_unet_b2(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B2 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B2 model
    """
    return _build_efficient_unet(get_efficientnet_b2_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)


def get_efficient_unet_b3(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B3 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B3 model
    """
    return _build_efficient_unet(get_efficientnet_b3_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)


def get_efficient_unet_b4(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B4 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B4 model
    """
    return _build_efficient_unet(get_efficientnet_b4_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)


def get_efficient_unet_b5(input_shape, out_channels=2, pretrained=False, block_type='transpose', concat_input=True):
    """Get a Unet model with Efficient-B5 encoder
    :param input_shape: shape of input (cannot have None element)
    :param out_channels: the number of output channels
    :param pretrained: True for ImageNet pretrained weights
    :param block_type: forwarded to get_efficient_unet_vs1
    :param concat_input: if True, input image will be concatenated with the last conv layer
    :return: an EfficientUnet_B5 model
    """
    return _build_efficient_unet(get_efficientnet_b5_encoder, input_shape, out_channels, pretrained,
                                 block_type, concat_input)

## Double Mix-up

### Ver1

In [None]:
def decoder_block(inputs,n_filter,X_skip,skip=None):
    """Upsample ``inputs`` by 2 and fuse it with two skip tensors.

    When ``skip`` is given it is passed through ``conv_block(n_filter)`` and
    concatenated with the upsampled tensor and ``X_skip``. When ``skip`` is None,
    ``X_skip`` is not used. The result always goes through ``conv_block(n_filter)``.
    """
    upsampled = Conv2DTranspose(n_filter, (2,2), strides=(2, 2), padding='same', kernel_initializer='he_normal')(inputs)
    if skip is None:
        merged = upsampled
    else:
        attention = conv_block(n_filter)(skip)
        merged = Concatenate()([upsampled, X_skip, attention])

    return conv_block(n_filter)(merged)

def dow_block(inputs, kernel_size=(2,2),stride=(2,2)):
    """Apply max pooling with the given pool size and strides to ``inputs``."""
    pooling = MaxPooling2D(kernel_size, strides=stride)
    return pooling(inputs)


In [None]:
def encoderSegnet(input): #input_s=(128,128,1)
  """Build a five-stage convolutional encoder on top of ``input``.

  Stages 1-4 are conv_block followed by dow_block; stage 5 is a conv_block
  only. Stage i is created with block_name 'output_block_<i>'.

  Returns:
    (output tensor, Model from ``input`` to that tensor).
  """
  stage_widths = [64, 128, 256, 512, 512]
  features = input
  for stage, width in enumerate(stage_widths[:-1], start=1):
    features = conv_block(width, block_name='output_block_' + str(stage))(features)
    features = dow_block(features)

  # Final stage: no pooling afterwards.
  features = conv_block(stage_widths[-1], block_name='output_block_' + str(len(stage_widths)))(features)
  encoder_model = Model(input, features)
  return features, encoder_model

In [None]:
# Layer names to read skip connections from, ordered from 'output_block_4' down to
# 'output_block_1'. They follow the 'output_block_<n>' block_name pattern used in encoderSegnet.
# NOTE(review): the only visible consumer is the commented-out seg_net; confirm it is still needed.
list_skip = ["output_block_4", "output_block_3", "output_block_2", "output_block_1"]

In [None]:
'''def seg_net(input, list_skip = list_skip, X_skip=skip, out_channels=3): #input_shape= (192,288,2)
  o, encoder = encoderSegnet(input) #input_s = input_shape
  skip_connect=[encoder.get_layer(i).output for i in list_skip]
  num_filters = [512,256, 128, 64]

  #o = encoder.output
  o = ASPP(o,128)
  
  for i, f in enumerate(num_filters):
    o = decoder_block(f,X_skip[i],skip=skip_connect[i])(o)
  
  o = Conv2D(out_channels,(3, 3), padding='same', kernel_initializer='he_normal')(o)
  # yn = Activation('softmax')(o[...,:-1])
  # bn = o[...,-1:]
  # output = Concatenate()([yn,bn])
  if out_channels > 1 : 
    output = Activation('softmax', name = 'softmax')(o)
  else :
    output = Activation('sigmoid', name = 'sigmoid')(o)
  return output #Model(encoder.input,output)'''

"def seg_net(input, list_skip = list_skip, X_skip=skip, out_channels=3): #input_shape= (192,288,2)\n  o, encoder = encoderSegnet(input) #input_s = input_shape\n  skip_connect=[encoder.get_layer(i).output for i in list_skip]\n  num_filters = [512,256, 128, 64]\n\n  #o = encoder.output\n  o = ASPP(o,128)\n  \n  for i, f in enumerate(num_filters):\n    o = decoder_block(f,X_skip[i],skip=skip_connect[i])(o)\n  \n  o = Conv2D(out_channels,(3, 3), padding='same', kernel_initializer='he_normal')(o)\n  # yn = Activation('softmax')(o[...,:-1])\n  # bn = o[...,-1:]\n  # output = Concatenate()([yn,bn])\n  if out_channels > 1 : \n    output = Activation('softmax', name = 'softmax')(o)\n  else :\n    output = Activation('sigmoid', name = 'sigmoid')(o)\n  return output #Model(encoder.input,output)"

In [None]:
def residual_attention_concate_UNet(input):
    """Build the encoder + bottleneck model used as the first network's backbone.

    Four stages of (Conv2D 3x3 -> BatchNormalization -> Swish) x 2 followed by a
    squeeze-excite block named 'conv1' .. 'conv4' and 2x2 max pooling; the
    fourth stage has Dropout(0.5) before pooling. The pooled tensor goes through
    ASPP, a 3x3 convolution, two residual blocks, a Swish layer named 'conv5'
    and a final Dropout(0.5).

    Returns:
        Model from ``input`` to the bottleneck output.
    """
    features = BatchNormalization()(input)

    stages = ((64, 'conv1'), (128, 'conv2'), (256, 'conv3'), (512, 'conv4'))
    for width, tag in stages:
        for _ in range(2):
            features = Conv2D(width, 3, padding = 'same')(features)
            features = BatchNormalization()(features)
            features = Swish()(features)
        features = squeeze_excite_block(name_block=tag)(features)
        # Only the deepest stage is regularised with dropout before pooling.
        if tag == 'conv4':
            features = Dropout(0.5)(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
    features = ASPP(features, 1024)

    # Bottleneck
    bottleneck = Conv2D(1024, (3, 3), activation=None, padding="same")(features)
    for _ in range(2):
        bottleneck = residual_block(bottleneck, 1024)
    bottleneck = Swish(name='conv5')(bottleneck)
    bottleneck = Dropout(0.5)(bottleneck)

    return Model(input, bottleneck)

In [None]:
'''
skip_connections = []

    model = residual_attention_concate_UNet(inputs)
    names = ["conv1", "conv2", "conv3", "conv4"]
    for name in names:
        skip_connections.append(model.get_layer(name).output)

    output = model.get_layer("conv5").output
    return output, skip_connections

    num_filters = [64, 128, 256, 512]
    skip_connections = []
    x = inputs
'''

'\nskip_connections = []\n\n    model = residual_attention_concate_UNet(inputs)\n    names = ["conv1", "conv2", "conv3", "conv4"]\n    for name in names:\n        skip_connections.append(model.get_layer(name).output)\n\n    output = model.get_layer("conv5").output\n    return output, skip_connections\n\n    num_filters = [64, 128, 256, 512]\n    skip_connections = []\n    x = inputs\n'

In [None]:
def encoder1(inputs):
    """First-stage encoder: four conv/pool levels followed by a 1024-filter bottleneck.

    Args:
        inputs: tensor fed to the first `conv_block`.

    Returns:
        (bottleneck, skips): the bottleneck tensor after Conv2D -> Swish ->
        Dropout(0.5), and the list of `conv_block` outputs captured before
        each pooling step, ordered from the first level to the last.
    """
    skips = []
    features = inputs

    # Each level: conv_block at the given width, remember it, then halve with 2x2 max-pool.
    for width in (64, 128, 256, 512):
        features = conv_block(features, width)
        skips.append(features)
        features = MaxPool2D((2, 2))(features)

    # Bottleneck (the residual_block stages are currently disabled).
    bottleneck = Conv2D(1024, (3, 3), activation=None, padding="same")(features)
    bottleneck = Swish()(bottleneck)
    bottleneck = Dropout(0.5)(bottleneck)

    return bottleneck, skips

def decoder1(inputs, skip_connections):
    """First-stage decoder: four transposed-conv upsampling levels with skip concatenation.

    NOTE: `skip_connections` is reversed IN PLACE. `build_model` passes the
    same list on to `decoder2` afterwards and `decoder2` does not reverse it
    again, so this side effect is relied upon by the caller.

    Args:
        inputs: tensor to upsample.
        skip_connections: list of encoder tensors; after the in-place
            reversal, element `i` is concatenated at decoder level `i`.

    Returns:
        The tensor produced by the last decoder level.
    """
    widths = [512, 256, 128, 64]
    skip_connections.reverse()  # in-place on purpose; see docstring
    out = inputs

    for level in range(len(widths)):
        width = widths[level]
        out = Conv2DTranspose(width, (3, 3), strides=(2, 2), padding="same")(out)
        out = Concatenate()([out, skip_connections[level]])
        out = Dropout(0.2)(out)
        out = conv_block(out, width)
        # (residual_block stages are currently disabled)
        out = Swish()(out)

    return out

In [None]:
def encoder2(inputs):
    """Second-stage encoder: four conv/pool levels, no bottleneck.

    Args:
        inputs: tensor fed to the first `conv_block`.

    Returns:
        (pooled, skips): the tensor after the last 2x2 max-pool, and the list
        of `conv_block` outputs captured before each pooling step.
    """
    skips = []
    features = inputs

    for width in (64, 128, 256, 512):
        features = conv_block(features, width)
        skips.append(features)
        features = MaxPool2D((2, 2))(features)

    return features, skips

def decoder2(inputs, skip_1, skip_2):
    """Second-stage decoder: bilinear upsampling fused with skips from both encoders.

    NOTE: `skip_2` is reversed IN PLACE; `skip_1` is used in the order it is
    received (in `build_model` it has already been reversed by `decoder1`).

    Args:
        inputs: tensor to upsample.
        skip_1: list of first-encoder tensors, indexed by decoder level.
        skip_2: list of second-encoder tensors; reversed in place before use.

    Returns:
        The tensor produced by the last decoder level.
    """
    widths = [512, 256, 128, 64]
    skip_2.reverse()  # in-place, mirrors decoder1's handling of its skips
    out = inputs

    for level in range(len(widths)):
        out = UpSampling2D((2, 2), interpolation='bilinear')(out)
        out = Concatenate()([out, skip_1[level], skip_2[level]])
        out = conv_block(out, widths[level])

    return out

In [None]:
def build_model(shape):
    """Assemble the two-stage encoder/decoder network.

    The first stage produces `outputs1`, which is multiplied with the network
    input to form the input of the second stage. Only the second-stage output
    is exposed by the returned model.

    Args:
        shape: input shape passed to `Input`.

    Returns:
        A Keras `Model` mapping the input to the second-stage output.
    """
    inputs = Input(shape)

    # Stage 1: encoder -> ASPP -> decoder -> output head.
    stage1, skip_1 = encoder1(inputs)
    stage1 = ASPP(stage1, 64)
    stage1 = decoder1(stage1, skip_1)  # also reverses skip_1 in place
    outputs1 = output_block(stage1)

    # Gate the original input with the first-stage prediction.
    gated = inputs * outputs1

    # Stage 2: shares skip_1 with stage 1 in addition to its own skips.
    stage2, skip_2 = encoder2(gated)
    stage2 = ASPP(stage2, 128)
    stage2 = decoder2(stage2, skip_1, skip_2)
    outputs2 = output_block1(stage2)

    return Model(inputs, outputs2)

generator = build_model((128, 128, 1))
generator.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 128, 128, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 576         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
swish (Swish)                   (None, 128, 128, 64) 0           batch_normalization[0][0]        
______________________________________________________________________________________________

### Ver 2

In [None]:
def encoder1(inputs):
    """First-stage encoder built on an ImageNet-initialised VGG19 backbone.

    Args:
        inputs: tensor used as the backbone's `input_tensor`.

    Returns:
        (output, skips): the `block5_conv4` activation, and the activations of
        `block1_conv2`, `block2_conv2`, `block3_conv4` and `block4_conv4`
        (in that order) to be used as skip connections.
    """
    backbone = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_tensor=inputs)

    skip_layer_names = ("block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4")
    skips = [backbone.get_layer(layer_name).output for layer_name in skip_layer_names]

    return backbone.get_layer("block5_conv4").output, skips

def decoder1(inputs, skip_connections):
    """First-stage decoder: bilinear upsampling with skip concatenation.

    NOTE: `skip_connections` is reversed IN PLACE. `build_model` hands the
    same list to `decoder2` afterwards without reversing it again, so the
    caller depends on this side effect.

    Args:
        inputs: tensor to upsample.
        skip_connections: list of encoder tensors; after the in-place
            reversal, element `i` is concatenated at decoder level `i`.

    Returns:
        The tensor produced by the last decoder level.
    """
    widths = [256, 128, 64, 32]
    skip_connections.reverse()  # in-place on purpose; see docstring
    out = inputs

    for level in range(len(widths)):
        out = UpSampling2D((2, 2), interpolation='bilinear')(out)
        out = Concatenate()([out, skip_connections[level]])
        out = conv_block(out, widths[level])

    return out

def encoder2(inputs):
    """Second-stage encoder: four conv/pool levels with widths 32 to 256.

    Args:
        inputs: tensor fed to the first `conv_block`.

    Returns:
        (pooled, skips): the tensor after the last 2x2 max-pool, and the list
        of `conv_block` outputs captured before each pooling step.
    """
    skips = []
    features = inputs

    for width in (32, 64, 128, 256):
        features = conv_block(features, width)
        skips.append(features)
        features = MaxPool2D((2, 2))(features)

    return features, skips

def decoder2(inputs, skip_1, skip_2):
    """Second-stage decoder: bilinear upsampling fused with skips from both encoders.

    NOTE: `skip_2` is reversed IN PLACE; `skip_1` is used in the order it is
    received (in `build_model` it has already been reversed by `decoder1`).

    Args:
        inputs: tensor to upsample.
        skip_1: list of first-encoder tensors, indexed by decoder level.
        skip_2: list of second-encoder tensors; reversed in place before use.

    Returns:
        The tensor produced by the last decoder level.
    """
    widths = [256, 128, 64, 32]
    skip_2.reverse()  # in-place, mirrors decoder1's handling of its skips
    out = inputs

    for level in range(len(widths)):
        out = UpSampling2D((2, 2), interpolation='bilinear')(out)
        out = Concatenate()([out, skip_1[level], skip_2[level]])
        out = conv_block(out, widths[level])

    return out

In [None]:
def build_model(shape):
    """Assemble the two-stage network (Ver 2) with a fused 3-channel softmax head.

    The first-stage prediction gates the network input for the second stage;
    both stage predictions are then concatenated and mapped to 3 channels by
    a 1x1 softmax convolution.

    Args:
        shape: input shape passed to `Input`.

    Returns:
        A Keras `Model` mapping the input to the fused softmax output.
    """
    inputs = Input(shape)

    # Stage 1: encoder -> ASPP -> decoder -> output head.
    stage1, skip_1 = encoder1(inputs)
    stage1 = ASPP(stage1, 64)
    stage1 = decoder1(stage1, skip_1)  # also reverses skip_1 in place
    outputs1 = output_block(stage1)

    # Gate the original input with the first-stage prediction.
    gated = inputs * outputs1

    # Stage 2: shares skip_1 with stage 1 in addition to its own skips.
    stage2, skip_2 = encoder2(gated)
    stage2 = ASPP(stage2, 64)
    stage2 = decoder2(stage2, skip_1, skip_2)
    outputs2 = output_block(stage2)

    # Fuse both stage predictions into the final class map.
    fused = Concatenate()([outputs1, outputs2])
    fused = Conv2D(3, (1,1), padding="same", activation="softmax")(fused)

    return Model(inputs, fused)

#generator = build_model((128, 128, 3))
#generator.summary()

In [None]:
def refinement():
  """Return a callable that applies the refinement stack to a tensor.

  The stack is four Conv2D + Swish pairs (64 filters with a 1x1 kernel,
  64 with 3x3, 64 with 3x3, then 1 filter with 3x3), followed by
  `squeeze_excite_block`. New layers are created on every invocation of the
  returned callable.
  """
  # (filters, kernel_size) for each Conv2D + Swish pair, in application order.
  conv_specs = ((64, 1), (64, 3), (64, 3), (1, 3))

  def call(y):
    for n_filters, kernel in conv_specs:
      y = Conv2D(n_filters, kernel,  padding="same")(y)
      y = Swish()(y)
    return squeeze_excite_block()(y)

  return call

def reverse_attention():
  """Return a callable `call(x, y)` implementing a reverse-attention update.

  The callable computes the weight `1 - sigmoid(y)`, multiplies `x` by it,
  passes the result through `refinement()` and adds that to `y`.
  """
  def call(x,y):
    weights = Activation('sigmoid')(y)
    weights = -1*(weights)+1          # complement of the sigmoid map
    weighted = Multiply()([weights,x])
    return y+refinement()(weighted)   # residual update of y
  return call

In [None]:
# NOTE(review): this cell is a pasted fragment of a function body -- the lines
# below the first are indented and conv1..conv4 / conv9 are not defined in this
# cell, so it is not runnable on its own.
#
# Cascade of reverse-attention steps: a 1-channel map derived from conv9 is
# resized to each of x4, x3, x2, x1 in turn, updated by reverse_attention(),
# and a 128x128 copy of every intermediate result is kept.
x1 = conv1; x2 = conv2; x3 = conv3; x4 = conv4;
    # Initial 1-channel map from conv9 (1x1 conv, no activation).
    x5 = Conv2D(1, (1,1), padding="same", activation=None)(conv9) #None  

    # Level 4: resize the map to x4's height/width, then update it with x4.
    y5 = Resizing(x4.shape[1],x4.shape[2], interpolation='bilinear', crop_to_aspect_ratio=True)(x5) 
    y4 = reverse_attention()(x4,y5)
    _y4 = Resizing(128, 128, interpolation='bilinear', crop_to_aspect_ratio=True)(y4) 

    # Level 3: same pattern with x3.
    y4 = Resizing(x3.shape[1],x3.shape[2], interpolation='bilinear', crop_to_aspect_ratio=True)(y4) 
    y3 = reverse_attention()(x3,y4)
    _y3 = Resizing(128, 128, interpolation='bilinear', crop_to_aspect_ratio=True)(y3) 

    # Level 2: same pattern with x2.
    y3 = Resizing(x2.shape[1],x2.shape[2], interpolation='bilinear', crop_to_aspect_ratio=True)(y3) 
    y2 = reverse_attention()(x2,y3)
    _y2 = Resizing(128, 128, interpolation='bilinear', crop_to_aspect_ratio=True)(y2) 

    # Level 1: same pattern with x1.
    y2 = Resizing(x1.shape[1],x1.shape[2], interpolation='bilinear', crop_to_aspect_ratio=True)(y2) 
    y1 = reverse_attention()(x1,y2)
    _y1 = Resizing(128, 128, interpolation='bilinear', crop_to_aspect_ratio=True)(y1) 

    #_y1 = Conv2D(3, (1,1), padding="same", activation='softmax')(_y1)

    # Tuple of the four 128x128 maps, from the x4 level to the x1 level.
    output = _y4, _y3, _y2, _y1 

In [None]:
 def tversky_kahneman(self, target, output, alpha=0.5, beta=0.5, gamma=4/3, smooth=1e-10):
    """Tversky-Kahneman loss summed over the 3 classes.

    Fixes over the previous version:
      * `K.flatten` produced a rank-1 tensor, so every `K.sum(..., axis=1)`
        was an invalid reduction; `K.batch_flatten` is used instead so that
        each class keeps its own row.
      * The `[..., 1:]` slice was applied after the permutation, where the
        last axis is no longer the class axis; it has been removed.
      * `smooth` was accepted but never used; it now stabilises the ratio.

    Args:
        self: unused.
        target: rank-4 integer label map whose last axis has size 1; it is
            one-hot encoded with depth 3 here.
            # assumes layout (batch, height, width, 1) -- TODO confirm
        output: rank-4 prediction tensor with 3 channels on the last axis.
        alpha: weight of the false-positive term.
        beta: weight of the false-negative term.
        gamma: exponent of the probability-weighting function.
        smooth: small constant added to numerator and denominator of `p`.

    Returns:
        Scalar tensor: the sum of the per-class loss values.
    """
    target = tf.one_hot(tf.squeeze(tf.cast(target, tf.uint8), axis=-1), depth=3)
    target = tf.cast(target, tf.float32)

    # Move the class axis to the front so that batch_flatten yields one row per class.
    target = K.permute_dimensions(target, (3,1,2,0))
    output = K.permute_dimensions(output, (3,1,2,0))
    target_flat = K.batch_flatten(target)
    output_flat = K.batch_flatten(output)

    # Per-class confusion counts (soft, since `output` holds probabilities).
    true_pos = K.sum(target_flat * output_flat, axis=1)
    true_neg = K.sum((1-target_flat) * (1-output_flat), axis=1)
    false_neg = K.sum(target_flat * (1-output_flat), axis=1)
    false_pos = K.sum((1-target_flat) * output_flat, axis=1)

    agreement = 0.5*(true_pos + true_neg)
    p = (agreement + smooth)/(agreement + alpha*false_pos + beta*false_neg + smooth)
    p_gamma = K.pow(p, gamma)       # p^gamma
    _p_gamma = K.pow(1-p, gamma)    # (1-p)^gamma
    loss_tensor = _p_gamma/K.pow(p_gamma + _p_gamma, 1/gamma)

    # One value per class; their sum is the loss.
    return K.sum(loss_tensor)

In [None]:

def fcn_model(input_shape, num_classes, weights=None):
    """Build and compile the FCN segmentation model.

    Binary segmentation (`num_classes == 2`) uses a single sigmoid output
    channel with `dice_coef_loss`; any other value uses `num_classes` softmax
    channels with `tversky_kahneman`.

    Fixes over the previous version: the score, upsampling and prediction
    layers now emit `num_classes` channels and the prediction layer uses the
    selected `activation` (previously both were hard-coded to 1 / 'sigmoid',
    so the multi-class branch could never produce a class map). The binary
    path is unchanged. `SGD` is configured through `learning_rate` instead of
    the deprecated `lr` alias.

    Args:
        input_shape: shape passed to the `Input` layer.
        num_classes: number of segmentation classes (2 selects the binary path).
        weights: optional path passed to `model.load_weights`.

    Returns:
        The compiled Keras model.
    """
    # Binary classification => sigmoid + Dice loss; multi-class => softmax.
    # NOTE(review): the `tversky_kahneman` visible in this notebook takes a
    # leading `self` argument -- confirm the name bound here accepts
    # (y_true, y_pred) as Keras expects.
    if num_classes == 2:
        num_classes = 1
        loss = dice_coef_loss #tversky_kahneman
        activation = 'sigmoid'
    else:
        loss = tversky_kahneman #'categorical_crossentropy'
        activation = 'softmax'

    def conv_mvn(x, filters):
        """3x3 same-padded ReLU convolution followed by the `mvn` Lambda."""
        x = Conv2D(filters, kernel_size=(3,3), padding='same', activation='relu', use_bias=True)(x)
        return Lambda(mvn)(x)

    def score(x):
        """1x1 convolution producing one score map per output channel."""
        return Conv2D(num_classes, (1,1), strides=(1,1), use_bias=True, padding='valid')(x)

    def upsample(x):
        """Stride-2 transposed convolution (deconvolution): maps a tensor
        shaped like a convolution's output back towards the shape of its input."""
        return Conv2DTranspose(num_classes, (3,3), strides=(2,2), use_bias=False, padding='valid')(x)

    # Create the layers of the FCN model.
    data = Input(shape=input_shape, dtype='float', name='data')
    mvn0 = Lambda(mvn)(data)
    pad0 = ZeroPadding2D(padding=5)(mvn0)

    # Block 1: 3 x conv(64).
    x = pad0
    for _ in range(3):
        x = conv_mvn(x, 64)
    mxp1 = MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid')(x)

    # Block 2: 4 x conv(128); its output feeds a skip score map.
    x = mxp1
    for _ in range(4):
        x = conv_mvn(x, 128)
    mvn7 = x
    mxp2 = MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid')(mvn7)

    # Block 3: 4 x conv(256); its output feeds a skip score map.
    x = mxp2
    for _ in range(4):
        x = conv_mvn(x, 256)
    mvn11 = x
    mxp3 = MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid')(mvn11)

    drop2 = Dropout(rate=0.5)(mxp3)

    # Block 4: 4 x conv(512).
    x = drop2
    for _ in range(4):
        x = conv_mvn(x, 512)
    drop3 = Dropout(rate=0.5)(x)

    # Decoder: score the deepest features, then upsample and fuse with the
    # score maps of blocks 3 and 2. `crop` is defined elsewhere in the notebook.
    score_conv15 = score(drop3)
    upsample1 = upsample(score_conv15)
    score_conv11 = score(mvn11)

    crop1 = Lambda(crop)([upsample1, score_conv11])
    fuse1 = average([crop1, upsample1])

    upsample2 = upsample(fuse1)
    score_conv7 = score(mvn7)

    crop2 = Lambda(crop)([upsample2, score_conv7])
    fuse2 = average([crop2, upsample2])

    upsample3 = upsample(fuse2)
    crop3 = Lambda(crop)([data, upsample3])

    # Prediction head: channel count and activation follow the selected mode.
    predict = Conv2D(num_classes, (1,1), strides=(1,1), padding='valid',
                     activation=activation, use_bias=True)(crop3)
    model = models.Model(inputs=data, outputs=predict)
    if weights is not None:
        model.load_weights(weights)
    sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
    model.compile(optimizer=sgd, loss=loss,
                  metrics=[dice_coef, jaccard_coef])

    return model