In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [14]:
class SE_ResidualUnit_bottleneck(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)

        self.SE = [keras.layers.GlobalAvgPool2D(),
                   keras.layers.Flatten(),
                   keras.layers.Dense(filters//4, activation='relu'),
                   keras.layers.Dense(filters, activation='sigmoid'),
                   keras.layers.Reshape([1,1,filters])]

        self.block_layers = [keras.layers.Conv2D(filters//4, kernel_size=(1,1), strides=1, padding='same', use_bias=False),
                             keras.layers.BatchNormalization(),
                             self.activation,
                             keras.layers.Conv2D(filters//4, kernel_size=(3,4), strides=strides, padding='same', use_bias=False),
                             keras.layers.BatchNormalization(),
                             self.activation,
                             keras.layers.Conv2D(filters, kernel_size=(1,1), strides=1, padding='same', use_bias=False),
                             keras.layers.BatchNormalization()]
        
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [keras.layers.Conv2D(filters, kernel_size=(1,1), strides=strides, padding='same', use_bias=False),
                                keras.layers.BatchNormalization()]

    def call(self, x):
        inputs = tf.identity(x)
        x1 = tf.identity(x)
        x2 = tf.identity(x)
        
        for layer in self.block_layers:
            x1 = layer(x1)

        for layer in self.SE:
          x2 = layer(x2)
        
        x = x1*x2
        
        for layer in self.skip_layers:
            inputs = layer(inputs)
        
        
        return self.activation(x + inputs)

In [26]:
model = keras.Sequential()
model.add(SE_ResidualUnit_bottleneck(64))
model.add(SE_ResidualUnit_bottleneck(64))
model.add(SE_ResidualUnit_bottleneck(32, 2))
model.add(SE_ResidualUnit_bottleneck(32))
model.add(SE_ResidualUnit_bottleneck(16, 2))
model.add(SE_ResidualUnit_bottleneck(16))

model.compile()

In [27]:
model.build(input_shape=(10, 240, 320, 1))

In [28]:
x = np.random.random((10, 240, 320, 1))

In [29]:
model(x)

<tf.Tensor: shape=(10, 60, 80, 16), dtype=float32, numpy=
array([[[[0.00000000e+00, 0.00000000e+00, 3.31589341e-01, ...,
          2.84726508e-02, 4.04343009e-03, 0.00000000e+00],
         [2.45062611e-03, 0.00000000e+00, 6.90533698e-01, ...,
          4.48524617e-02, 4.83307056e-02, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 5.28574884e-01, ...,
          8.99763480e-02, 6.03114907e-03, 0.00000000e+00],
         ...,
         [0.00000000e+00, 0.00000000e+00, 3.06047976e-01, ...,
          1.11211412e-01, 0.00000000e+00, 0.00000000e+00],
         [7.54932815e-04, 0.00000000e+00, 4.45089012e-01, ...,
          1.38171231e-02, 1.48886191e-02, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 5.95401406e-01, ...,
          2.27043331e-02, 7.58567546e-03, 0.00000000e+00]],

        [[0.00000000e+00, 0.00000000e+00, 6.22441947e-01, ...,
          8.83543566e-02, 6.54606223e-02, 0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 1.45857438e-01, ...,
         