# ConvLSTM functional encoder/decoder

In [1]:
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
print("tensorflow version:",tf.__version__)

tensorflow version: 2.2.0-dlenv


## Input dimension variables

In [9]:
frames = 24
channels = 1
pixels_x = 21
pixels_y = 21

## Define the model

NOTE: Old versions of Tensorflow (<2.2) may not have support for initial_state on ConvLSTM2D cells. 
Please follow the instructions here: https://stackoverflow.com/questions/50253138/convlstm2d-initial-state-assertion-error to fix it.

In [4]:
model_name = 'full_stack_'+str(frames)+'f_'+str(channels)+'c_'+str(pixels_x)+'x_'+str(pixels_y)+'y'
print(model_name)

full_stack_24f_1c_21x_21y


# full stack model:
1. split inputs into values (inputA) and targets (inputB)

2. Run each through a deep convolutional layer to reduce dimensionality

3. run conv'd values through encoder

4. pass encoder hidden and cell states to decoder, also pass conv'd targets as decoder input

5. get decoder output and run it through deconvolution network to reassemble prediction

In [11]:
### Conv Stack Vars
input_shape = (channels,pixels_x,pixels_y)
weight_decay=1e-5
filters=[8, 16, 16]
kernel_sizes = [(5,5), (3,3), (3,3)]
strides=[(2,2),(1,1),(2,2)]
bias_init=0.1
output_activation=tf.nn.sigmoid,

########### 1. INPUT PARSING ###########

inputs = layers.Input(name="model_input",
                      shape = (2*frames, channels, pixels_x, pixels_y))

inputA, inputB = tf.split(inputs, 2, axis=1, num=None, name='split')

inputA = layers.GaussianNoise(0.1)(inputA)

########### 2a. CONV A (input t=0) ################

name = "convA"

Conv2D_1 = layers.Conv2D(name=name+"1",
                         data_format='channels_first',
                         filters=filters[0],
                         kernel_size=kernel_sizes[0],
                         strides=strides[0],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_1 = layers.BatchNormalization(axis=1, name=name+"1_bn")
Conv2D_2 = layers.Conv2D(name=name+"2",
                         data_format='channels_first',
                         filters=filters[1],
                         kernel_size=kernel_sizes[1],
                         strides=strides[1],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_2 = layers.BatchNormalization(axis=1, name=name+"2_bn")
Conv2D_3 = layers.Conv2D(name=name+"3",
                         data_format='channels_first',
                         filters=filters[2],
                         kernel_size=kernel_sizes[2],
                         strides=strides[2],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_3 = layers.BatchNormalization(axis=1, name=name+"3_bn")


stack = Conv2D_1(inputA)
stack = BN_1(stack)
stack = Conv2D_2(stack)
stack = BN_2(stack)
stack = Conv2D_3(stack)
convA_output = BN_3(stack)

###### 2b. CONV B (input t=1, aka target) ##############

name = "convB"

Conv2D_1B = layers.Conv2D(name=name+"1",
                         data_format='channels_first',
                         filters=filters[0],
                         kernel_size=kernel_sizes[0],
                         strides=strides[0],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_1B = layers.BatchNormalization(axis=1, name=name+"1_bn")
Conv2D_2B = layers.Conv2D(name=name+"2",
                         data_format='channels_first',
                         filters=filters[1],
                         kernel_size=kernel_sizes[1],
                         strides=strides[1],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_2B = layers.BatchNormalization(axis=1, name=name+"2_bn")
Conv2D_3B = layers.Conv2D(name=name+"3",
                         data_format='channels_first',
                         filters=filters[2],
                         kernel_size=kernel_sizes[2],
                         strides=strides[2],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
BN_3B = layers.BatchNormalization(axis=1, name=name+"3_bn")


stackB = Conv2D_1B(inputB)
stackB = BN_1B(stackB)
stackB = Conv2D_2(stackB)
stackB = BN_2B(stackB)
stackB = Conv2D_3B(stackB)
convB_output = BN_3(stackB)

####################################
####### ENCODER-DEC0DER ############

########## 3. ENCODER ##############

# first time-step
i = 0
# get input_images and output_images as one tensor
encoder_input = tf.expand_dims(convA_output, 0)
encoder_cell_1 = layers.ConvLSTM2D(name="encoder{}".format(i+1),
                                   filters = filters[-1],
                                   kernel_size=(5,5),
                                   padding='same',
                                   data_format='channels_first',
                                   return_sequences=True,
                                   return_state=True)
_, state_h, state_c = encoder_cell_1(encoder_input)

encoder_states = [state_h, state_c]

##### 4. DECODER #####

decoder_input = tf.expand_dims(convB_output, 0)
decoder_cell_1 = layers.ConvLSTM2D(name="decoder{}".format(i+1),
                                   filters = filters[-1],
                                   kernel_size=(5,5),
                                   padding='same',
                                   data_format='channels_first',
                                   return_sequences=True,
                                   return_state=True)
decoder_output, _, _ = decoder_cell_1(decoder_input, initial_state = encoder_states)
reshaped_decoder_output = tf.reshape(decoder_output,shape=[-1, filters[-1], 3, 3])


#################################
###### 5. DECONV STACK ##########

name = "deconv"
rev_filters = filters[::-1]
rev_filters = rev_filters[1:] + [channels]
rev_kernel_sizes = kernel_sizes[::-1]
rev_strides = strides[::-1]

deConv2D_1 = layers.Conv2DTranspose(name=name+"1",
                         data_format='channels_first',
                         filters=rev_filters[0],
                         kernel_size=rev_kernel_sizes[0],
                         strides=rev_strides[0],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
deBN_1 = layers.BatchNormalization(axis=1, name=name+"1_bn")
deConv2D_2 = layers.Conv2DTranspose(name=name+"2",
                         data_format='channels_first',
                         filters=rev_filters[1],
                         kernel_size=rev_kernel_sizes[1],
                         strides=rev_strides[1],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
deBN_2 = layers.BatchNormalization(axis=1, name=name+"2_bn")
deConv2D_3 = layers.Conv2DTranspose(name=name+"3",
                         data_format='channels_first',
                         filters=rev_filters[2],
                         kernel_size=rev_kernel_sizes[2],
                         strides=rev_strides[2],
                         kernel_initializer=tf.keras.initializers.GlorotNormal(),
                         activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                         activation="relu",
                         )
deBN_3 = layers.BatchNormalization(axis=1, name=name+"3_bn")

# reshaped = reshaper(decoder_output)
stack = deConv2D_1(reshaped_decoder_output)
stack = deBN_1(stack)
stack = deConv2D_2(stack)
stack = deBN_2(stack)
stack = deConv2D_3(stack)
stack = deBN_3(stack)
deconv_output = tf.expand_dims(stack, 0)

#######################

full_model = tf.keras.Model(name="Full_stack",
                       inputs = inputs,
                       outputs = deconv_output)
full_model.summary()

ValueError: Input 0 of layer convA1 is incompatible with the layer: expected ndim=4, found ndim=5. Full shape received: [None, 24, 1, 21, 21]

## Compile the model

In [None]:
full_model.compile(loss='KLDivergence',
                  optimizer='adadelta',
                  metrics=['accuracy', 'mean_absolute_error'])

## Pickle the model

In [None]:
tf.keras.models.save_model(
    model = full_model,
    filepath = '../models/'+model_name+'.h5',
    overwrite=True,
    include_optimizer=True,
    save_format='tf',
    signatures=None
)

In [None]:
inputs = tf.Variable(tf.random.uniform([1,1,8,3,3], -1, 1))
inputs.shape

In [None]:
tf.squeeze(inputs, axis=0).shape

In [None]:
rev_filters = filters[::-1]
rev_filters = rev_filters[1:] + [channels]
rev_ksizes = kernel_sizes[::-1]
rev_strides = strides[::-1]
print(rev_filters)
print(rev_ksizes)
print(rev_strides)