In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.activations import *
from tensorflow.keras.regularizers import *
from tensorflow.keras import Input

from typing import Union , Tuple , Dict , List

In [None]:

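'''
References :
1. N. Ibtehaz, M. S. Rahman, "MultiResUNet: Rethinking the U-Net Architecture
   for Multimodal Biomedical Image Segmentation" (2019) - https://arxiv.org/abs/1902.04049
'''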
def conv_bn(input_tensors: tf.Tensor,
            filters: int,
            kernel_size: Tuple[int, int],
            activation: str = 'relu',
            padding: str = 'same',
            *,
            batch_norm: bool = False) -> tf.Tensor:
    """
    Apply a Conv2D layer, optionally followed by BatchNormalization.

    @params : input_tensors : feature map fed to the convolution
    @params : filters : number of output channels
    @params : kernel_size : (kh, kw) convolution window
    @params : activation : activation name passed to Keras
    @params : padding : padding mode passed to Conv2D
    @params : batch_norm : if False (default) this is a single Conv2D with
              `activation` fused in, exactly as before. If True the layer
              order is Conv2D (no bias) -> BatchNormalization -> activation.
    @output : the output tensor of the last layer applied
    """
    if not batch_norm:
        return Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      activation=activation,
                      padding=padding)(input_tensors)

    # BatchNormalization re-centres the activations, so a conv bias would be
    # redundant; the non-linearity goes after normalisation.
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               padding=padding,
               use_bias=False)(input_tensors)
    x = BatchNormalization()(x)
    return Activation(activation)(x)


def conv2dTrans_bn(input_tensors: tf.Tensor,
                   filters: int,
                   kernel_size: Tuple[int, int],
                   activation: str = 'relu',
                   padding: str = 'same',
                   *,
                   batch_norm: bool = False) -> tf.Tensor:
    """
    Apply a Conv2DTranspose layer, optionally followed by BatchNormalization.

    No `strides` argument is passed, so the layer uses Keras' default stride.

    @params : input_tensors : feature map fed to the transposed convolution
    @params : filters : number of output channels
    @params : kernel_size : (kh, kw) convolution window
    @params : activation : activation name passed to Keras
    @params : padding : padding mode passed to Conv2DTranspose
    @params : batch_norm : if False (default) this is a single
              Conv2DTranspose with `activation` fused in, exactly as before.
              If True the layer order is Conv2DTranspose (no bias) ->
              BatchNormalization -> activation.
    @output : the output tensor of the last layer applied
    """
    if not batch_norm:
        return Conv2DTranspose(filters=filters,
                               kernel_size=kernel_size,
                               activation=activation,
                               padding=padding)(input_tensors)

    # Bias is redundant ahead of BatchNormalization; activate after it.
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        padding=padding,
                        use_bias=False)(input_tensors)
    x = BatchNormalization()(x)
    return Activation(activation)(x)


def ResPath(input_tensors: tf.float32,
            filters: int,
            n_repeats: int,
            activation: str = 'relu',
            padding: str = 'same') -> tf.float32:

    '''
    ResPath : replacement for the plain skip connection of a U-Net.

    One step builds a 3x3 conv and a 1x1 conv on the same input and sums
    them with Add; that sum is the input of the next step. The step is
    chained `n_repeats` times (MultiResUnet passes 4, 3, 2, 1 going from
    the shallowest skip to the deepest).

    @params : input_tensors : skip-connection feature map
    @params : filters : number of filters of both convs in every step
    @params : n_repeats : how many steps to chain (0 returns the input as-is)
    @params : activation, padding : forwarded unchanged to every conv
    '''

    out = input_tensors
    for _step in range(n_repeats):
        # Build the 3x3 branch before the 1x1 branch, then merge them.
        wide_branch = conv_bn(input_tensors=out,
                              filters=filters,
                              kernel_size=(3, 3),
                              activation=activation,
                              padding=padding)
        narrow_branch = conv_bn(input_tensors=out,
                                filters=filters,
                                kernel_size=(1, 1),
                                activation=activation,
                                padding=padding)
        out = Add()([wide_branch, narrow_branch])
    return out


def num_feature_maps(num_features: int,
                     alpha: float) -> int:
    '''
    Scale a feature-map count by a fractional rate.

    @params : num_features : base number of feature maps
              (MultiResUnet's ResBlock stages use 32, 64, 128, 256, 512)
    @params : alpha : scaling rate; ResBlock uses 0.167, 0.333 and 0.5
    @output : int(alpha * num_features), i.e. the product truncated toward
              zero (equal to floor for non-negative inputs)
    '''
    return int(alpha * num_features)



def ResBlock(input_tensors: tf.float32,
             filters: int,
             activation: str = 'relu',
             padding: str = 'same') -> tf.float32:

    """
    MultiRes block: replaces the plain conv2d stages of a U-Net.

    Three 3x3 conv layers are chained, each fed by the previous one, so their
    receptive fields grow roughly like 3x3, 5x5 and 7x7. Their outputs are
    concatenated on the channel axis and added to a 1x1 conv of the block
    input (the residual shortcut). The sum goes through `activation` and then
    BatchNormalization.

    @params : input_tensors : feature map fed to the block
    @params : filters : base width W; the three chained convs get
              int(W * 0.167), int(W * 0.333) and int(W * 0.5) filters and
              the shortcut gets their sum, so the Add operands match
    @params : activation : activation used by every conv and after the Add
              (previously ignored in favour of a hard-coded 'relu')
    @params : padding : padding mode forwarded to every conv; the Add
              presumably needs 'same' so spatial sizes agree - TODO confirm
    @output : tensor whose channel count is the sum of the branch filters
    """
    decrease_rate = [0.167, 0.333, 0.5]
    branch_filters = [num_feature_maps(num_features=filters, alpha=rate)
                      for rate in decrease_rate]

    x = input_tensors
    result = []
    for n_filters in branch_filters:
        x = conv_bn(input_tensors=x,
                    filters=n_filters,
                    kernel_size=(3, 3),
                    activation=activation,
                    padding=padding)
        result.append(x)

    # Shortcut width must equal the concatenated width for Add to work.
    x4 = conv_bn(input_tensors=input_tensors,
                 filters=sum(branch_filters),
                 kernel_size=(1, 1),
                 activation=activation,
                 padding=padding)

    x = Concatenate(axis=-1)(result)
    output = Add()([x, x4])
    output = Activation(activation)(output)
    return BatchNormalization()(output)



def MultiResUnet(height: int,
                 width: int,
                 color_channels: int,
                 num_classes: int) -> Model:
    """
    Build the MultiResUnet segmentation model.

    Encoder: a 32-filter 3x3 conv stem, then five ResBlock stages with base
    widths 32, 64, 128, 256, 512, each followed by 2x2 max pooling.
    Decoder: starting from the (un-pooled) output of the fifth ResBlock, four
    stages. Each stage concatenates a Conv2DTranspose of the running tensor
    with a ResPath of the matching pooled encoder map, then upsamples by 2.
    ResPath repeats go 1, 2, 3, 4 from the deepest skip to the shallowest.

    @params : height, width : input spatial size; each must be a multiple of
              16 (or None for a dynamic dimension). Otherwise the upsampled
              decoder tensors do not line up with the pooled skips, or the
              output size differs from the input size.
    @params : color_channels : number of input channels
    @params : num_classes : number of output channels. With 1 class the head
              uses sigmoid, because softmax over a single channel is
              identically 1.0. Otherwise it uses softmax.
    @output : keras Model named 'MultiResUnet'
    @raises : ValueError if height or width is not a multiple of 16
    """
    for dim_name, dim in (('height', height), ('width', width)):
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f'{dim_name} must be a multiple of 16, got {dim}')

    input_shape = (height, width, color_channels)
    inputs = Input(shape=input_shape)
    out = conv_bn(input_tensors=inputs,
                  filters=32,
                  kernel_size=(3, 3),
                  activation='relu',
                  padding='same')

    pooling_result = []
    for n in range(5):
        x = ResBlock(input_tensors=out,
                     filters=32 * 2**n,
                     activation='relu',
                     padding='same')
        out = MaxPooling2D(pool_size=(2, 2))(x)
        pooling_result.append(out)

    # Deepest-first skips. The last pooled map (after the fifth ResBlock) is
    # dropped: the decoder starts from `x`, that ResBlock's un-pooled output.
    skips = pooling_result[::-1][1:]
    for idx, skip in enumerate(skips):
        n_filters = skip.shape[-1]
        x_1 = conv2dTrans_bn(input_tensors=x,
                             filters=n_filters,
                             kernel_size=(2, 2),
                             activation='relu',
                             padding='same')
        x_2 = ResPath(input_tensors=skip,
                      filters=n_filters,
                      n_repeats=idx + 1,
                      activation='relu',
                      padding='same')
        x = Concatenate()([x_1, x_2])
        x = UpSampling2D()(x)
    x = BatchNormalization()(x)

    head_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    output = conv_bn(input_tensors=x,
                     filters=num_classes,
                     kernel_size=(1, 1),
                     activation=head_activation,
                     padding='same')
    return Model(inputs=inputs, outputs=output, name='MultiResUnet')