# Stack single-input model
Convolve the input, pass the result through a ConvLSTM encoder and decoder, then deconvolve the decoder output back to the input resolution.

In [1]:
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# Record which TensorFlow build ran this notebook: ConvLSTM2D needs >= 2.2 for initial_state.
print("tensorflow version:", tf.__version__)

tensorflow version: 2.2.0-dlenv


## Input dimension variables

In [2]:
# Sequence length and per-frame layout (channels, x, y) of the model input.
frames, channels = 1, 1
pixels_x = pixels_y = 21

## Define the model

NOTE: TensorFlow versions older than 2.2 may not support `initial_state` on `ConvLSTM2D` layers.
If you are on an older version, follow the workaround described here: https://stackoverflow.com/questions/50253138/convlstm2d-initial-state-assertion-error

In [3]:
# Base path (directory + folder + name, no extension) used when saving the model.
model_dir, model_folder = ".", "/"
model_name = "stack_single_input_{}F".format(frames)
print("".join((model_dir, model_folder, model_name)))

./stack_single_input_1F


### Convolutional Stack A

In [5]:
### Conv Stack Vars
# Model input shape in channels_first layout (channels, x, y), derived from the
# dimension variables so the two cannot drift apart.
input_shape = (channels, pixels_x, pixels_y)
# NOTE(review): despite its name, this is applied below as an *activity*
# regularizer (penalizing layer outputs), not as kernel weight decay — confirm intent.
weight_decay = 1e-5
filters = [8, 16, 16]
kernel_sizes = [(5,5), (3,3), (3,3)]
strides = [(2,2), (1,1), (2,2)]
# NOTE(review): bias_init and output_activation are defined but not applied to
# any layer in this cell (the last deconv layer uses relu) — confirm intent.
bias_init = 0.1
output_activation = tf.nn.sigmoid  # the function itself (no trailing comma -> not a 1-tuple)


def conv_bn_block(layer_cls, layer_name, n_filters, kernel_size, stride):
    """Create one convolution layer and the BatchNormalization that follows it.

    The conv and deconv stacks share the same settings: channels_first data,
    Glorot-normal kernels, L2 activity regularization scaled by ``weight_decay``
    and relu activation.

    Args:
        layer_cls: ``layers.Conv2D`` or ``layers.Conv2DTranspose``.
        layer_name: name of the conv layer; the BN layer is named ``layer_name + "_bn"``.
        n_filters, kernel_size, stride: forwarded to ``layer_cls``.

    Returns:
        ``(conv_layer, bn_layer)`` — layer objects not yet called on a tensor.
    """
    conv_layer = layer_cls(name=layer_name,
                           data_format='channels_first',
                           filters=n_filters,
                           kernel_size=kernel_size,
                           strides=stride,
                           kernel_initializer=tf.keras.initializers.GlorotNormal(),
                           activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                           activation="relu")
    # axis=1 is the channel axis in channels_first layout.
    bn_layer = layers.BatchNormalization(axis=1, name=layer_name+"_bn")
    return conv_layer, bn_layer


########### INPUT PARSING ###########

inputs = layers.Input(name="model_input", shape=input_shape)  # (None, ch, x, y)
inputA = layers.GaussianNoise(0.1)(inputs)

########### CONV A (input t=0) ################

name = "convA"
Conv2D_1, BN_1 = conv_bn_block(layers.Conv2D, name+"1", filters[0], kernel_sizes[0], strides[0])
Conv2D_2, BN_2 = conv_bn_block(layers.Conv2D, name+"2", filters[1], kernel_sizes[1], strides[1])
Conv2D_3, BN_3 = conv_bn_block(layers.Conv2D, name+"3", filters[2], kernel_sizes[2], strides[2])

stack = inputA
for conv, bn in ((Conv2D_1, BN_1), (Conv2D_2, BN_2), (Conv2D_3, BN_3)):
    stack = bn(conv(stack))
# (None, filters[-1], 3, 3) for the default 21x21 input
convA_output = stack

####################################
####### ENCODER-DECODER ############

############# ENCODER ##############

# first time-step
i = 0
# Insert a length-1 time axis AFTER the batch axis: (None, C, H, W) ->
# (None, 1, C, H, W), i.e. the (samples, time, channels, rows, cols) layout
# ConvLSTM2D expects with channels_first. Expanding axis 0 instead would turn
# the whole batch into a single sequence and carry LSTM state across samples.
encoder_input = tf.expand_dims(convA_output, 1)
encoder_cell_1 = layers.ConvLSTM2D(name="encoder{}".format(i+1),
                                   filters=filters[-1],
                                   kernel_size=(5,5),
                                   padding='same',
                                   data_format='channels_first',
                                   return_sequences=True,
                                   return_state=True)
_, state_h, state_c = encoder_cell_1(encoder_input)

encoder_states = [state_h, state_c]

##### DECODER #####

decoder_input = tf.expand_dims(convA_output, 1)
decoder_cell_1 = layers.ConvLSTM2D(name="decoder{}".format(i+1),
                                   filters=filters[-1],
                                   kernel_size=(5,5),
                                   padding='same',
                                   data_format='channels_first',
                                   return_sequences=True,
                                   return_state=True)
decoder_output, _, _ = decoder_cell_1(decoder_input, initial_state=encoder_states)
# Fold the time axis back into the batch axis: (None, T, C, H, W) -> (None*T, C, H, W).
# C, H, W come from the conv stack instead of a hard-coded 3x3.
reshaped_decoder_output = tf.reshape(decoder_output,
                                     shape=[-1] + convA_output.shape[1:].as_list())

#################################
######## DECONV STACK ###########

name = "deconv"
# Mirror the conv stack: reversed kernels/strides, and filter counts shifted so
# the last transpose conv maps back to `channels`.
rev_filters = filters[::-1][1:] + [channels]
rev_kernel_sizes = kernel_sizes[::-1]
rev_strides = strides[::-1]

deConv2D_1, deBN_1 = conv_bn_block(layers.Conv2DTranspose, name+"1",
                                   rev_filters[0], rev_kernel_sizes[0], rev_strides[0])
deConv2D_2, deBN_2 = conv_bn_block(layers.Conv2DTranspose, name+"2",
                                   rev_filters[1], rev_kernel_sizes[1], rev_strides[1])
deConv2D_3, deBN_3 = conv_bn_block(layers.Conv2DTranspose, name+"3",
                                   rev_filters[2], rev_kernel_sizes[2], rev_strides[2])

stack = reshaped_decoder_output
for deconv, debn in ((deConv2D_1, deBN_1), (deConv2D_2, deBN_2), (deConv2D_3, deBN_3)):
    stack = debn(deconv(stack))
# Add a length-1 frame axis after the batch axis: (None, 1, channels, x, y).
deconv_output = tf.expand_dims(stack, 1)

#######################

full_model = tf.keras.Model(name="Full_stack",
                            inputs=inputs,
                            outputs=deconv_output)
full_model.summary()

Model: "Full_stack"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        [(None, 1, 21, 21)]  0                                            
__________________________________________________________________________________________________
gaussian_noise (GaussianNoise)  (None, 1, 21, 21)    0           model_input[0][0]                
__________________________________________________________________________________________________
convA1 (Conv2D)                 (None, 8, 9, 9)      208         gaussian_noise[0][0]             
__________________________________________________________________________________________________
convA1_bn (BatchNormalization)  (None, 8, 9, 9)      32          convA1[0][0]                     
_________________________________________________________________________________________

## Save the model

In [6]:
# Save the model as HDF5. The '.h5' filename and save_format must agree; tf.keras
# selects HDF5 from the '.h5' extension regardless, so state it explicitly.
# `signatures` only applies to the SavedModel ('tf') format and is omitted.
# NOTE(review): include_optimizer has no effect unless the model was compiled.
tf.keras.models.save_model(
    model=full_model,
    filepath=model_dir + model_folder + model_name + '.h5',
    overwrite=True,
    include_optimizer=True,
    save_format='h5',
)