In [1]:
# Import required libraries

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, Dropout, ConvLSTM2D, Activation
from tensorflow.keras.layers import concatenate, Conv3DTranspose, BatchNormalization

In [1]:
class Models:
    """Builders for a 3-D-conv / ConvLSTM encoder-decoder Keras model.

    All builders are static and use no instance state. They may be called on
    the class (``Models.ecoder_decoder_cnn_lstm(...)``) or on an instance.
    Correctly spelled aliases (``encoder_decoder_*``) are provided alongside
    the original ``ecoder_decoder_*`` names, which remain for existing callers.
    """

    # Net height/width reduction at the bottleneck: three stride-2 encoder
    # convs plus the stride-2 ConvLSTM2D. Each is undone by a stride-2
    # Conv3DTranspose whose output must match a skip tensor exactly, so the
    # input height and width must be multiples of this value.
    SPATIAL_DIVISOR = 16

    @staticmethod
    def ecoder_decoder_block(input_tensor):
        """One encoder stage: down-sample, up-sample back, fuse with the input.

        Args:
            input_tensor: 5-D channels-last tensor, presumably
                (batch, time, height, width, channels). Height and width must
                be even, so that the stride-2 transposed conv restores the
                input size and the concatenation lines up.

        Returns:
            (down, fused), where:
                down  -- 32-channel Conv3D output at half height/width. The
                         decoder uses it as a skip connection.
                fused -- 16-channel Conv3D output at the input's height/width,
                         computed from concat([input_tensor, upsampled]).
        """
        down = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 2, 2),
                      activation='relu', padding='same',
                      data_format='channels_last')(input_tensor)
        up = Conv3DTranspose(filters=16, kernel_size=(2, 3, 3), strides=(1, 2, 2),
                             padding='same', data_format='channels_last')(down)
        fused = concatenate([input_tensor, up], axis=-1)
        fused = BatchNormalization()(fused)
        fused = Conv3D(filters=16, kernel_size=(1, 3, 3), strides=(1, 1, 1),
                       activation='relu', padding='same',
                       data_format='channels_last')(fused)
        return down, fused

    @staticmethod
    def _decoder_stage(x, pre_skip, post_skip):
        """Up-sample x 2x in height/width, fuse with skips, conv + BN + ReLU.

        Layer creation order is fixed, so Keras' auto-generated layer names
        (and therefore saved-weight compatibility) are stable.
        """
        x = Conv3DTranspose(filters=16, kernel_size=(2, 3, 3), strides=(1, 2, 2),
                            padding='same', data_format='channels_last')(x)
        x = concatenate([x, pre_skip], axis=-1)
        x = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 1, 1),
                   padding='same', data_format='channels_last')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        return concatenate([x, post_skip], axis=-1)

    @staticmethod
    def ecoder_decoder_cnn_lstm(input_dim, dp):
        """Build the full encoder -> ConvLSTM bottleneck -> decoder model.

        Args:
            input_dim: per-sample input shape (no batch axis) with 4 entries in
                channels-last order: (time, height, width, channels).
                ConvLSTM2D treats the first entry as the time axis. Height and
                width, when known, must be multiples of ``SPATIAL_DIVISOR``.
            dp: dropout rate applied just before the output layer.

        Returns:
            An uncompiled ``Model`` whose output is a 1-channel sigmoid map
            with the input's time, height and width.

        Raises:
            ValueError: if ``input_dim`` does not have 4 entries, or if
                height/width are not multiples of ``SPATIAL_DIVISOR``.
        """
        shape = tuple(input_dim)
        if len(shape) != 4:
            raise ValueError(
                f"input_dim must be (time, height, width, channels); got {shape!r}")
        divisor = Models.SPATIAL_DIVISOR
        for axis_name, size in (("height", shape[1]), ("width", shape[2])):
            # Fail early with a clear message instead of a concat-shape error
            # deep inside graph construction.
            if size is not None and size % divisor != 0:
                raise ValueError(
                    f"input {axis_name} must be a multiple of {divisor}; got {size}")

        input_layer = Input(shape=shape)
        # Stem at full resolution; also the skip for the last decoder stage.
        stem = Conv3D(filters=16, kernel_size=(1, 3, 3), strides=(1, 1, 1),
                      activation='relu', padding='same',
                      data_format='channels_last')(input_layer)

        # ---------------------------- Encoder ---------------------------------
        # Stage 1: skip1_down at H/2; enc1 at H/2 (32 ch).
        skip1_down, stage1 = Models.ecoder_decoder_block(stem)
        enc1 = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 2, 2),
                      activation='relu', padding='same',
                      data_format='channels_last')(stage1)

        # Stage 2: skip2_down at H/4; enc2 at H/4 (32 ch).
        skip2_down, stage2 = Models.ecoder_decoder_block(enc1)
        enc2 = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 2, 2),
                      activation='relu', padding='same',
                      data_format='channels_last')(stage2)

        # Stage 3: skip3_down at H/8; enc3 at H/8 (32 ch).
        skip3_down, stage3 = Models.ecoder_decoder_block(enc2)
        enc3 = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 2, 2),
                      activation='relu', padding='same',
                      data_format='channels_last')(stage3)

        # Bottleneck: strided ConvLSTM2D down to H/16, keeping every time step.
        bottleneck = ConvLSTM2D(filters=16, kernel_size=(3, 3), strides=(2, 2),
                                activation='relu', padding='same',
                                return_sequences=True)(enc3)

        # ---------------------------- Decoder ---------------------------------
        dec3 = Models._decoder_stage(bottleneck, enc3, skip3_down)  # -> H/8
        dec2 = Models._decoder_stage(dec3, enc2, skip2_down)        # -> H/4
        dec1 = Models._decoder_stage(dec2, enc1, skip1_down)        # -> H/2

        # Final up-sample to full resolution; no BN/ReLU here (as originally built).
        dec0 = Conv3DTranspose(filters=16, kernel_size=(2, 3, 3), strides=(1, 2, 2),
                               padding='same', data_format='channels_last')(dec1)
        dec0 = concatenate([dec0, stem], axis=-1)
        dec0 = Conv3D(filters=32, kernel_size=(1, 3, 3), strides=(1, 1, 1),
                      padding='same', data_format='channels_last')(dec0)

        # NOTE(review): ConvLSTM2D already applies relu internally and is then
        # followed by BN + another ReLU; kept unchanged to preserve the model.
        temporal = ConvLSTM2D(filters=16, kernel_size=(3, 3), strides=(1, 1),
                              activation='relu', padding='same',
                              return_sequences=True)(dec0)
        temporal = BatchNormalization()(temporal)
        temporal = Activation('relu')(temporal)
        temporal = Dropout(dp)(temporal)

        # 1-channel sigmoid head. The kernel spans 2 adjacent time steps.
        # Original note said "240 x 320", presumably the intended frame size
        # (confirm).
        output_layer = Conv3D(filters=1, kernel_size=(2, 3, 3), strides=(1, 1, 1),
                              activation='sigmoid', padding='same',
                              data_format='channels_last')(temporal)

        return Model(input_layer, output_layer)

    # Correctly spelled aliases; the original "ecoder_*" names are kept so
    # existing callers keep working.
    encoder_decoder_block = ecoder_decoder_block
    encoder_decoder_cnn_lstm = ecoder_decoder_cnn_lstm
    