# In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.activations import *
from tensorflow.keras.regularizers import *
from tensorflow.keras import Input

from typing import Union , Tuple , Dict , List

# In [None]:
def Encoder(input_tensors: tf.Tensor) -> Tuple[tf.Tensor, List[tf.Tensor]]:
    """Build the contracting path of the network.

    Three stages are applied in sequence. Each stage runs two 3x3 ReLU
    convolutions (64, 128, then 256 filters) and is followed by a 2x2 max
    pooling with 'same' padding.

    Args:
        input_tensors: Tensor fed to the first convolution.

    Returns:
        A ``(pooled, skips)`` tuple. ``pooled`` is the output of the last
        pooling layer; ``skips`` holds each stage's pre-pooling output, in
        the order the stages were applied.
    """
    def conv(input_tensors: tf.Tensor,
             filters: int) -> tf.Tensor:
        """Apply two identical 3x3 'same'-padded ReLU convolutions."""
        kernel_size = (3, 3)
        padding = 'same'
        activation_fn = 'relu'
        kernel_initializer = 'he_normal'
        x = input_tensors
        for _ in range(2):
            x = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       padding=padding,
                       activation=activation_fn,
                       kernel_initializer=kernel_initializer)(x)
        return x

    base_filters = 64
    pool_size = (2, 2)
    x = input_tensors
    skips: List[tf.Tensor] = []
    for stage in range(3):
        # Filter count doubles at every stage: 64 -> 128 -> 256.
        x = conv(input_tensors=x, filters=base_filters * 2 ** stage)
        # Keep the pre-pooling features; they are returned alongside the
        # pooled tensor.
        skips.append(x)
        x = MaxPooling2D(pool_size=pool_size, padding='same')(x)
    return x, skips

 
def DenseLayers(input_tensors: tf.Tensor) -> tf.Tensor:
    """
    Bottleneck built from three chained convolutional blocks.

    Every block is two 3x3, 512-filter ReLU convolutions followed by one
    batch-normalization layer. Block 2 consumes block 1, block 3 consumes the
    channel-wise concatenation of blocks 1 and 2, and the returned tensor is
    the channel-wise concatenation of blocks 1 and 3.
    """
    def _block(features: tf.Tensor) -> tf.Tensor:
        # Both convolutions of a block share the same configuration.
        conv_kwargs = dict(filters=512,
                           kernel_size=(3, 3),
                           padding='same',
                           activation='relu',
                           kernel_initializer='he_normal')
        features = Conv2D(**conv_kwargs)(features)
        features = Conv2D(**conv_kwargs)(features)
        return BatchNormalization()(features)

    first = _block(input_tensors)
    second = _block(first)
    # The third block sees the first two blocks stacked along the channel axis.
    third = _block(Concatenate(axis=-1)([first, second]))
    # NOTE(review): the second block's output only reaches the result through
    # the third block; it is not part of the final concatenation.
    return Concatenate(axis=-1)([first, third])

 

def Decoder(input_dense: tf.Tensor,
            input_process: List[tf.Tensor]) -> tf.Tensor:
    """Build the expanding path that fuses encoder features through ConvLSTM2D.

    The bottleneck tensor is upsampled once with a strided transposed
    convolution, then fused with the encoder outputs from last to first.
    Every fusion concatenates along the channel axis, inserts a time axis of
    length one and applies a ConvLSTM2D layer. Between fusions a convolution
    block doubles the spatial resolution again.

    Args:
        input_dense: Tensor produced by the bottleneck.
        input_process: Encoder outputs in the order they were produced; they
            are consumed in reverse order and indices 0-2 of the reversed
            list are used.

    Returns:
        The output tensor of the last ConvLSTM2D layer (32 filters).
    """
    # Shared layer settings, bound before the helpers that close over them so
    # the helpers never depend on a name assigned later in this function.
    kernel_size = (3, 3)
    strides = (2, 2)
    padding = 'same'
    activation_fn = 'relu'
    kernel_initializer = 'he_normal'
    n_filters = 256

    def conv(input_tensors: tf.Tensor,
             filters: int) -> tf.Tensor:
        """Two ReLU convolutions, a stride-2 transposed convolution, then batch norm."""
        x = input_tensors
        for _ in range(2):
            x = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       padding=padding,
                       activation=activation_fn,
                       kernel_initializer=kernel_initializer)(x)
        # The transposed convolution has no activation of its own.
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding=padding,
                            kernel_initializer=kernel_initializer)(x)
        return BatchNormalization()(x)

    def fuse_skip(x: tf.Tensor,
                  skip: tf.Tensor,
                  filters: int) -> tf.Tensor:
        """Concatenate `x` with `skip` and run the pair through ConvLSTM2D."""
        x = Concatenate(axis=-1)([x, skip])
        # ConvLSTM2D input layout (channels_last):
        # [batch_size, num_frames, height, width, color_channels].
        # NOTE(review): calling tf.expand_dims directly on a symbolic Keras
        # tensor relies on TF2 Keras wrapping raw TF ops - confirm against
        # the Keras version in use.
        x = tf.expand_dims(x, axis=1)
        return ConvLSTM2D(filters=filters,
                          kernel_size=kernel_size,
                          padding=padding,
                          go_backwards=True,
                          kernel_initializer=kernel_initializer)(x)

    # Consume the encoder outputs starting from the last one produced.
    skips = input_process[::-1]

    x = Conv2DTranspose(filters=n_filters,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding=padding,
                        kernel_initializer=kernel_initializer)(input_dense)
    x = BatchNormalization()(x)
    x = Activation(activation_fn)(x)

    # Integer division keeps the filter counts exact ints (256/128/64/32).
    x = fuse_skip(x, skips[0], filters=n_filters)
    x = conv(input_tensors=x, filters=n_filters // 2)
    x = fuse_skip(x, skips[1], filters=n_filters // 4)
    x = conv(input_tensors=x, filters=n_filters // 4)
    return fuse_skip(x, skips[2], filters=n_filters // 8)


def Bi_Direc_ConvLSTM(height: int,
                      width: int,
                      color_channels: int,
                      num_classes: int) -> tf.keras.Model:
    """Assemble encoder, dense bottleneck and ConvLSTM decoder into one model.

    Args:
        height: First dimension of the per-sample input shape.
        width: Second dimension of the per-sample input shape.
        color_channels: Third (last) dimension of the per-sample input shape.
        num_classes: Number of filters of the final 1x1 softmax convolution.

    Returns:
        A Keras ``Model`` named 'Bi-Directional_ConvLSTM_Unet' mapping the
        input to the softmax output of the final convolution.
    """
    image = Input(shape=(height, width, color_channels))

    # The encoder yields its final pooled tensor plus the per-stage outputs
    # that the decoder consumes.
    bottleneck_in, skip_tensors = Encoder(input_tensors=image)
    decoded = Decoder(input_dense=DenseLayers(input_tensors=bottleneck_in),
                      input_process=skip_tensors)

    # 1x1 convolution producing one softmax channel per class.
    classifier = Conv2D(filters=num_classes,
                        kernel_size=(1, 1),
                        padding='same',
                        activation='softmax',
                        kernel_initializer='he_normal')
    return Model(inputs=image,
                 outputs=classifier(decoded),
                 name='Bi-Directional_ConvLSTM_Unet')