# Model Playground

In [1]:
import numpy as np

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

## Layer Class example

In [None]:
class Encoder(layers.Layer):
  """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

  def __init__(self,
               latent_dim=32,
               intermediate_dim=64,
               name='encoder',
               **kwargs):
    super(Encoder, self).__init__(name=name, **kwargs)
    self.dense_proj = layers.Dense(intermediate_dim, activation='relu')
    self.dense_mean = layers.Dense(latent_dim)
    self.dense_log_var = layers.Dense(latent_dim)
    self.sampling = Sampling()

  def call(self, inputs):
    x = self.dense_proj(inputs)
    z_mean = self.dense_mean(x)
    z_log_var = self.dense_log_var(x)
    z = self.sampling((z_mean, z_log_var))
    return z_mean, z_log_var, z

## Convolutional Stack Class

In [125]:
class ConvStack(layers.Layer):
  """Creates a Conv2D + BatchNorm stack.
  If input is (1,21,21), output is (1, 1, 64, 3, 3)"""

  def __init__(self,
               input_shape:tuple = (1,21,21),
               weight_decay=1e-5,
               filters=[32, 64, 64],
               kernel_sizes = [(5,5), (3,3), (3,3)],
               strides=[(2,2),(1,1),(2,2)],
               bias_init=0.1,
               output_activation=tf.nn.sigmoid,
               name='encoder',
               **kwargs):
    super(ConvStack, self).__init__(name=name, **kwargs)

    self.Conv2D_1 = layers.Conv2D(name="conv1",
                                  input_shape = input_shape,
                                 data_format='channels_first',
                                 filters=filters[0],
                                 kernel_size=kernel_sizes[0],
                                 strides=strides[0],
                                 kernel_initializer=tf.keras.initializers.GlorotNormal(),
                                 activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                                 activation=tf.nn.relu,
                                 )
    self.BN_1 = layers.BatchNormalization(axis=1, name="conv1_bn")
    self.Conv2D_2 = layers.Conv2D(name="conv2",
                                 data_format='channels_first',
                                 filters=filters[1],
                                 kernel_size=kernel_sizes[1],
                                 strides=strides[1],
                                 kernel_initializer=tf.keras.initializers.GlorotNormal(),
                                 activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                                 activation=tf.nn.relu,
                                 )
    self.BN_2 = layers.BatchNormalization(axis=1, name="conv2_bn")
    self.Conv2D_3 = layers.Conv2D(name="conv3",
                                 data_format='channels_first',
                                 filters=filters[2],
                                 kernel_size=kernel_sizes[2],
                                 strides=strides[2],
                                 kernel_initializer=tf.keras.initializers.GlorotNormal(),
                                 activity_regularizer=tf.keras.regularizers.l2(l=weight_decay),
                                 activation=tf.nn.relu,
                                 )
    self.BN_3 = layers.BatchNormalization(axis=1, name="conv3_bn")
    

  def call(self, inputs):
    """Construct the stack and pass the inputs thru"""
    stack = self.Conv2D_1(inputs)
    stack = self.BN_1(stack)
    stack = self.Conv2D_2(stack)
    stack = self.BN_2(stack)
    stack = self.Conv2D_3(stack)
    stack = self.BN_3(stack)
    # Add an empty dim to pass to Encoder
    return tf.expand_dims(stack, 0)

## test ConvStack class

In [132]:
inputs = tf.Variable(tf.random.uniform([1,1,21,21], -1, 1))
convstack = ConvStack()
convstack.call(inputs)
# stack.call(input for inputs in tf.unstack(inputs,num=24, axis=1))

<tf.Tensor: shape=(1, 1, 64, 3, 3), dtype=float32, numpy=
array([[[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [1.54450592e-02, 0.00000000e+00, 0.00000000e+00]],

         [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

         [[1.18545834e-02, 0.00000000e+00, 0.00000000e+00],
          [2.62208506e-02, 2.73919627e-02, 0.00000000e+00],
          [0.00000000e+00, 6.35615364e-02, 0.00000000e+00]],

         [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

         [[1.18993811e-01, 3.44521329e-02, 0.00000000e+00],
          [6.33240268e-02, 6.98077232e-02, 8.45875684e-03],
          [8.84230435e-02, 4.33336906e-02, 3.68882082e-02]],

         [[8.09396431e-02, 0.000

In [None]:
class Encoder(layers.Layer):
    def __init__(lstm_layers=1, lstm_ksize_input=(3, 3), lstm_ksize_hidden=(5, 5),
                 lstm_use_peepholes=True, lstm_cell_clip=None, lstm_bn_input_hidden=False, 
                 lstm_bn_hidden_hidden=False, lstm_bn_peepholes=False,):
        
        # input to call() will be a list of tensors of shape=(1, 1, 64, 3, 3)
        # create a list of InputLayers
        # create a list of ConvLSTM cells
        # for each frame, create an InputLayer and corresponding ConvLSTM cell
        

In [121]:
class LSTMConv2DPredictionModel(tf.keras.Model):
    def __init__(self, weight_decay=1e-5, filters=[32, 64, 64], kernel_sizes = [(5,5), (3,3), (3,3)],
                 strides=[(2,2),(1,1),(2,2)], bias_init=0.1, output_activation=tf.nn.sigmoid,
                 bn_feature_enc=True, bn_feature_dec=True, 
                 lstm_layers=1, lstm_ksize_input=(3, 3), lstm_ksize_hidden=(5, 5),
                 lstm_use_peepholes=True, lstm_cell_clip=None, lstm_bn_input_hidden=False, 
                 lstm_bn_hidden_hidden=False, lstm_bn_peepholes=False,
                 scheduled_sampling_decay_rate=None,
                 main_loss=tf.keras.losses.MeanSquaredError(), alpha_main_loss=1.0,
                 alpha_gdl_loss=1.0, alpha_ssim_loss=0.0):
        assert len(filters) == len(kernel_sizes) and len(filters) == len(strides), "Encoder/Decoder configuration mismatch."
        
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._bias_init = bias_init
        self._output_activation = output_activation
        self._bn_feature_enc = bn_feature_enc
        self._bn_feature_dec = bn_feature_dec
        self.weight_decay = weight_decay
        
        # lstm
        self._lstm_layers = lstm_layers
        self._lstm_ksize_input = lstm_ksize_input
        self._lstm_ksize_hidden = lstm_ksize_hidden
        self._lstm_use_peepholes = lstm_use_peepholes
        self._lstm_cell_clip = lstm_cell_clip
        self._lstm_bn_input_hidden = lstm_bn_input_hidden
        self._lstm_bn_hidden_hidden = lstm_bn_hidden_hidden
        self._lstm_bn_peepholes = lstm_bn_peepholes
        
        # scheduled sampling
        self._scheduled_sampling_decay_rate = scheduled_sampling_decay_rate
        
        # main loss function, that will be combined with pisel-wise GDL
        self._main_loss = main_loss
        self._alpha_main_loss = alpha_main_loss
        self._alpha_gdl_loss = alpha_gdl_loss
        self._alpha_ssim_loss = alpha_ssim_loss
        
        super(LSTMConv2DPredictionModel, self).__init__(weight_decay)
        
    def call(self, inputs, targets, feeds,
              is_training, device_scope):
        input_shape = inputs.shape.as_list()
        target_shape = targets.shape.as_list()

        # Conv-Encoder
        conv_input_seq = []

        # convert from shape [bs, t, h, w, c] to list([bs, h, w, c])
        input_seq = tf.unstack(inputs, axis=1)

        for i in range(len(input_seq)):
            conv = self._conv_stack(input_seq[i], is_training)
            conv_input_seq.append(conv)
        
        # shape for convolved inputs that flows through our LSTMs
        feat_repr_shape = conv.shape.as_list()[1:]
        return conv_input_seq
        
    def _conv_stack(self, inputs, is_training):
        """Creates a 2D convolutional stack."""
        current_inputs = tf.expand_dims(inputs, 0)
        for i, (f, k, s) in enumerate(zip(self._filters, self._kernel_sizes, self._strides)):
            conv = layers.Conv2D(name="conv{}".format(i + 1),
                                 data_format='channels_first',
                                 filters=f,
                                 kernel_size=k,
                                 strides=s,
                                 kernel_initializer=tf.keras.initializers.GlorotNormal(),
#                                      bias_initializer=self._bias_init,
                                 activity_regularizer=tf.keras.regularizers.l2(l=self.weight_decay),
                                 activation=tf.nn.relu,
                                 )(current_inputs)
            if self._bn_feature_enc:
                conv = layers.BatchNormalization(axis=1, name="conv{}_bn".format(i + 1))(conv)

            current_inputs = conv
        
        return current_inputs

In [122]:
mod = LSTMConv2DPredictionModel()
inputs = tf.Variable(tf.random.uniform([24,1,21,21], -1, 1))
mod.compile()
mod.build(input_shape=[24,1,21,21])
mod.summary()

ValueError: Currently, you cannot build your model if it has positional or keyword arguments that are not inputs to the model, but are required for its `call` method. Instead, in order to instantiate and build your model, `call` your model on real tensor data with all expected call arguments.

In [115]:
test_stack = mod.inference(inputs=inputs, targets=inputs,
              feeds=None, is_training=False,
              device_scope=False)
test_stack
# test_stack.compile(loss='categorical_crossentropy',
#                   optimizer='adadelta',
#                   metrics=['mean_absolute_error'])
# test_stack.summary()

[<tf.Tensor: shape=(1, 64, 3, 3), dtype=float32, numpy=
 array([[[[2.29488928e-02, 1.28877088e-02, 0.00000000e+00],
          [0.00000000e+00, 7.54721388e-02, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
 
         [[5.55545948e-02, 3.66725177e-01, 1.08393438e-01],
          [8.88168439e-02, 1.08020224e-01, 3.32204193e-01],
          [0.00000000e+00, 5.49434423e-01, 0.00000000e+00]],
 
         [[0.00000000e+00, 5.64309210e-02, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
 
         [[0.00000000e+00, 2.34790519e-02, 0.00000000e+00],
          [0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
          [2.95179039e-02, 0.00000000e+00, 1.16316698e-01]],
 
         [[0.00000000e+00, 3.94059509e-01, 3.88746932e-02],
          [0.00000000e+00, 5.62738657e-01, 2.12506175e-01],
          [1.52152061e-01, 3.21079642e-01, 0.00000000e+00]],
 
         [[1.78171471e-01, 0.