# In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.activations import *
from tensorflow.keras.regularizers import *
from tensorflow.keras import Input
from tensorflow.keras.applications.resnet import ResNet50 , ResNet101

from typing import Union , Tuple , Dict , List

# In [None]:
def RCU(input_tensors: tf.float32,
        filters: int) -> tf.float32:
    """Residual Conv Unit (RCU).

    Runs the input through two ``ReLU -> 3x3 Conv2D('same')`` stages and adds
    a 1x1 Conv2D projection of the (un-activated) input as the shortcut. The
    projection maps the input to ``filters`` channels so the ``Add`` is valid
    whatever channel count the input has. All convolutions use stride 1 and
    'same' padding, so the spatial size is unchanged.

    Args:
        input_tensors: feature map to refine.
        filters: number of output channels.

    Returns:
        Sum of the residual branch and the projected shortcut.
    """
    def relu_then_conv(tensor):
        # Pre-activation ordering: activation first, then convolution.
        activated = ReLU()(tensor)
        return Conv2D(filters=filters,
                      kernel_size=(3, 3),
                      padding='same')(activated)

    residual_branch = relu_then_conv(relu_then_conv(input_tensors))
    shortcut = Conv2D(filters=filters,
                      kernel_size=(1, 1),
                      padding='same')(input_tensors)
    return Add()([residual_branch, shortcut])



def CRP(input_tensors: tf.float32,
        filters: int,
        n_blocks: int = 1) -> tf.float32:
    """Chained Residual Pooling (CRP).

    Applies ReLU to the input, then a chain of ``n_blocks`` pool-conv stages
    (5x5 max-pool with stride 1, then 3x3 conv). Each stage consumes the
    previous stage's output, and every stage output is summed with the
    activated input.

    Args:
        input_tensors: feature map to pool. Its channel count must equal
            ``filters`` because it is summed directly with the conv outputs.
        filters: number of channels produced by each pool-conv stage.
        n_blocks: length of the pooling chain (default 1, single stage).

    Returns:
        ``relu(input) + sum(stage_i outputs)``.

    Raises:
        ValueError: if ``n_blocks`` < 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be >= 1, got {n_blocks}")

    def pool_conv(input_tensor: tf.float32,
                  filters: int) -> tf.float32:
        padding = 'same'
        # Stride-1 pooling with 'same' padding keeps the spatial size, so the
        # result can be added back to the activated input.
        x = MaxPooling2D(pool_size=(5, 5),
                         strides=(1, 1),
                         padding=padding)(input_tensor)  # was: outer `input_tensors` (pre-ReLU)
        return Conv2D(filters=filters,
                      kernel_size=(3, 3),
                      padding=padding)(x)

    x = ReLU()(input_tensors)
    out = x
    path = x
    for _ in range(n_blocks):
        # Each stage pools the previous stage's output (the "chain"),
        # and its result is accumulated into the running residual sum.
        path = pool_conv(input_tensor=path, filters=filters)
        out = Add()([out, path])
    return out



def MRF(input_tensors_1: Union[tf.Tensor, None],
        input_tensors_2: Union[tf.Tensor, None],
        filters: int) -> tf.float32:
    """Multi-resolution Fusion (MRF).

    Projects each input to ``filters`` channels with a 3x3 conv ('same').
    If a second (lower-resolution) input is given, its projection is resized
    to the spatial size of the first input's projection and the two are
    summed.

    Args:
        input_tensors_1: higher-resolution feature map (must not be None).
        input_tensors_2: lower-resolution feature map, normally half the
            spatial size of ``input_tensors_1``, or None for a single path.
        filters: number of output channels.

    Returns:
        The projected ``input_tensors_1`` alone, or its sum with the resized
        projection of ``input_tensors_2``.
    """
    def ResizeImage(input_tensor: tf.float32,
                    reference: tf.float32) -> tf.float32:
        # Resize to the reference's exact spatial size rather than a fixed
        # x2 scale: backbone stages round odd sizes up (e.g. 13 -> 7), so
        # 2 * low != high and a fixed scale would break the Add below.
        # For an exact 2:1 ratio this is identical to scaling by 2.
        height, width = reference.shape[1], reference.shape[2]
        if height is None or width is None:
            # Static size unknown; fall back to the runtime shape.
            size = tf.shape(reference)[1:3]
        else:
            size = (int(height), int(width))
        return tf.image.resize(images=input_tensor, size=size)

    kernel_size = (3, 3)
    padding = 'same'
    x1 = Conv2D(filters=filters,
                kernel_size=kernel_size,
                padding=padding)(input_tensors_1)
    if input_tensors_2 is None:
        return x1
    x2 = Conv2D(filters=filters,
                kernel_size=kernel_size,
                padding=padding)(input_tensors_2)
    x2 = ResizeImage(input_tensor=x2, reference=x1)
    return Add()([x1, x2])
    

def RefineBlock(input_tensors_1: Union[tf.Tensor, None],
                input_tensors_2: Union[tf.Tensor, None],
                filters: int) -> tf.float32:
    """One RefineNet refinement stage: RCU -> MRF -> CRP -> RCU.

    Each available input is first passed through its own RCU. The results are
    fused by MRF (``input_tensors_2``, when given, is the lower-resolution
    path). The fused map then goes through CRP and a final RCU.

    Args:
        input_tensors_1: higher-resolution feature map (required).
        input_tensors_2: lower-resolution feature map from the previous
            stage, or None for the first (coarsest) stage.
        filters: channel width used by every sub-unit.

    Returns:
        The refined feature map.
    """
    high_res = RCU(input_tensors=input_tensors_1, filters=filters)
    # Only build the second RCU when there is a second path to refine.
    low_res = (None if input_tensors_2 is None
               else RCU(input_tensors=input_tensors_2, filters=filters))
    fused = MRF(input_tensors_1=high_res,
                input_tensors_2=low_res,
                filters=filters)
    pooled = CRP(input_tensors=fused, filters=filters)
    return RCU(input_tensors=pooled, filters=filters)


def RefineNet(height: int,
              width: int,
              color_channels: int,
              num_classes: int,
              weights: Union[str, None] = None,
              feature_layers: Tuple[str, str, str, str] = (
                  'conv2_block2_2_relu',
                  'conv3_block2_2_relu',
                  'conv4_block2_2_relu',
                  'conv5_block2_2_relu')) -> Model:
    """Build a 4-cascade RefineNet segmentation model on a ResNet101 backbone.

    Args:
        height, width, color_channels: input image shape (fixed, per sample).
        num_classes: number of output classes (softmax channels).
        weights: passed to ``ResNet101(weights=...)``. Defaults to None, i.e.
            a randomly initialised backbone. Use ``'imagenet'`` for a
            pretrained one; Keras requires 3 input channels in that case.
        feature_layers: names of the four backbone layers tapped as RefineNet
            inputs, ordered from highest to lowest resolution. The defaults
            reproduce the original taps.
            NOTE(review): these are intermediate ReLUs inside block 2 of each
            stage, not the stage outputs (e.g. 'conv4_block23_out') used by
            the RefineNet paper. This presumably leaves most of conv4/conv5
            unused; confirm this is intended.

    Returns:
        A ``Model`` named 'RefineNet' mapping (height, width, color_channels)
        images to (height, width, num_classes) per-pixel softmax scores.

    Raises:
        ValueError: if ``feature_layers`` does not name exactly 4 layers.
    """
    if len(feature_layers) != 4:
        raise ValueError(
            f"feature_layers must name exactly 4 layers, got {len(feature_layers)}")

    input_shape = (height, width, color_channels)
    kernel_size = (1, 1)
    padding = 'same'
    filters = 256

    inputs = Input(shape=input_shape)
    resnet101 = ResNet101(include_top=False,
                          weights=weights,
                          input_tensor=inputs)

    output1, output2, output3, output4 = [
        resnet101.get_layer(name).output for name in feature_layers]

    # Coarsest stage uses twice the channel width; each later stage fuses
    # the previous (lower-resolution) result with the next finer backbone tap.
    x = RefineBlock(input_tensors_1=output4,
                    input_tensors_2=None, filters=int(filters*2))
    x = RefineBlock(input_tensors_1=output3,
                    input_tensors_2=x, filters=filters)
    x = RefineBlock(input_tensors_1=output2,
                    input_tensors_2=x, filters=filters)
    x = RefineBlock(input_tensors_1=output1,
                    input_tensors_2=x, filters=filters)
    x = tf.image.resize(images=x, size=(height, width))
    # NOTE(review): two consecutive 1x1 convs with no activation between them
    # compose to a single linear map; kept to preserve the original layer set.
    x = Conv2D(filters=num_classes,
               kernel_size=kernel_size,
               padding=padding)(x)
    outputs = Conv2D(filters=num_classes,
                     kernel_size=kernel_size,
                     padding=padding,
                     activation='softmax')(x)
    return Model(inputs=inputs, outputs=outputs, name='RefineNet')