In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
import tensorflow as tf
from keras import Model
from keras.layers import Layer
from keras import layers, regularizers, optimizers, losses, metrics, callbacks

In [69]:
class SqueezeExcite(Layer):
    def __init__(self, ratio=4, **kwargs):
        super(SqueezeExcite, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        self.squeeze = layers.GlobalAveragePooling2D()
        self.excite1 = layers.Dense(channels // self.ratio, activation='relu')
        self.excite2 = layers.Dense(channels, activation='sigmoid')
        super(SqueezeExcite, self).build(input_shape)
    
    def call(self, inputs):
        x = self.squeeze(inputs)
        x = self.excite1(x)
        x = self.excite2(x)
        x = layers.Reshape((1, 1, x.shape[-1]))(x)
        s = layers.Multiply()([inputs, x])
        return layers.Add()([inputs, s])
    
    def compute_output_shape(self, input_shape):
        return input_shape

In [139]:
def encoder_layer(inp, filt, k=3, pad='same', drop=0.1, pool=(2,2)):
    _ = layers.Conv2D(filt, k, activity_regularizer=regularizers.l1(1e-6), padding=pad)(inp)
    _ = SqueezeExcite()(_)
    _ = layers.GroupNormalization(groups=-1)(_)
    _ = layers.PReLU()(_)
    _ = layers.MaxPooling2D(pool)(_)
    _ = layers.SpatialDropout2D(drop)(_)
    return _

In [165]:
def recurrent_decoder(z_input, residuals, rnn_filters=[16,64,256], previous_timestep=None, 
                      dropout=0.1, leaky_slope=0.3, out_channels:int=2):
    
    def recurrent_step(inp, filt, res, kern=3, pad='same', drop=dropout):
        y = layers.ConvLSTM2D(filt, kern, padding=pad)(inp)
        y = layers.GroupNormalization(groups=-1)(y)
        y = layers.LeakyReLU(leaky_slope)(y)
        y = layers.Conv2DTranspose(filt, kern, padding=pad, strides=2)(y)
        y = layers.SpatialDropout2D(drop)(y)
        y = layers.Concatenate()([y, res])
        y = layers.Conv2D(filt, kern, padding=pad)(y)
        y = layers.Activation('sigmoid')(y)
        y = tf.expand_dims(y,1)
        return y
    
    def recurrent_last(inp, filt, kern=3, pad='same', drop=dropout):
        y = layers.ConvLSTM2D(filt, kern, padding=pad)(inp)
        y = layers.GroupNormalization(groups=-1)(y)
        y = layers.LeakyReLU(leaky_slope)(y)
        y = layers.Conv2DTranspose(filt, kern, padding=pad, strides=2)(y)
        y = layers.SpatialDropout2D(drop)(y)
        y = layers.Conv2D(out_channels, kern, padding=pad)(y)
        y = layers.Activation('sigmoid')(y)
        y = tf.expand_dims(y, 1)
        return y
    
    _ = tf.expand_dims(z_input, 1)
    _ = recurrent_step(_, rnn_filters[0], residuals[0])
    _ = recurrent_step(_, rnn_filters[1], residuals[1])
    _ = recurrent_last(_, rnn_filters[2])
    
    if previous_timestep != None:
        _ = layers.Concatenate(axis=1)([previous_timestep, _])
    
    return _

In [166]:
def make_model():
    inputs = layers.Input(shape=(160, 160, 5))

    # Encoder
    x1 = encoder_layer(inputs, 16)
    x2 = encoder_layer(x1, 64)
    x3 = encoder_layer(x2, 256)

    # Decoder
    _ = recurrent_decoder(x3, [x2, x1])

    outputs = x3
    return Model(inputs=inputs, outputs=outputs)

In [167]:
temp = np.random.normal(size=(1, 160, 160, 5)).astype(np.float32)
model = make_model()
print('# parameters: {:,}'.format(model.count_params()))

ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```


In [154]:
x1 = encoder_layer(temp, 16)
x2 = encoder_layer(x1, 64)
x3 = encoder_layer(x2, 256)
print('x1: {} | x2: {} | x3: {}'.format(x1.shape, x2.shape, x3.shape))

zz = tf.expand_dims(x3, 1)
y1 = recurrent_step(zz, 256, x2)
y1 = recurrent_step(y1, 64, x1)
y1 = recurrent_last(y1, 16)
print('y1: {}'.format(y1.shape))

x1: (1, 80, 80, 16) | x2: (1, 40, 40, 64) | x3: (1, 20, 20, 256)
y1: (1, 1, 160, 160, 2)
