In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [25]:
class SE_ResidualUnit_bottleneck(keras.layers.Layer):
    """Bottleneck residual unit with squeeze-and-excitation (SE) recalibration.

    Main path: 1x1 conv -> BN -> activation -> (3, 4) conv (strided) -> BN ->
    activation -> 1x1 conv -> BN. The first two convolutions use
    ``filters // 4`` channels, the last one ``filters`` channels.

    The SE branch turns the main-path output into one sigmoid gate per channel
    (global average pool -> Dense(filters // 4, relu) -> Dense(filters, sigmoid))
    and the main-path output is multiplied by that gate before being added to
    the skip connection.

    The skip connection is the identity unless the main path changes the tensor
    shape, in which case it is a strided 1x1 conv followed by BN.

    Args:
        filters: Number of output channels of the unit.
        strides: Stride of the (3, 4) convolution (and of the skip projection).
        activation: Activation identifier resolved via ``keras.activations.get``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, strides=1, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() and build() can refer to them.
        self.filters = filters
        self.strides = strides
        self.activation = keras.activations.get(activation)

        # GlobalAvgPool2D already returns a (batch, channels) tensor, so no
        # Flatten layer is needed in front of the Dense layers.
        self.SE = [keras.layers.GlobalAvgPool2D(),
                   keras.layers.Dense(filters // 4, activation='relu'),
                   keras.layers.Dense(filters, activation='sigmoid'),
                   keras.layers.Reshape([1, 1, filters])]

        # NOTE(review): the non-square (3, 4) kernel is kept as in the original;
        # confirm it is intentional.
        self.block_layers = [
            keras.layers.Conv2D(filters // 4, kernel_size=(1, 1), strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters // 4, kernel_size=(3, 4), strides=strides, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
        ]

        # A spatial down-sampling always needs a projection on the skip path.
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = self._make_projection()

    def _make_projection(self):
        """Return the 1x1 conv + BN layers that reshape the skip connection."""
        return [keras.layers.Conv2D(self.filters, kernel_size=(1, 1), strides=self.strides,
                                    padding='same', use_bias=False),
                keras.layers.BatchNormalization()]

    def build(self, input_shape):
        """Add a skip projection when the input channel count differs from ``filters``.

        Without it the final addition fails for ``strides == 1`` whenever the
        incoming tensor does not already have ``filters`` channels.
        Assumes channels-last inputs (the SE gate is reshaped to [1, 1, filters]).
        """
        in_channels = input_shape[-1]
        if not self.skip_layers and in_channels is not None and in_channels != self.filters:
            self.skip_layers = self._make_projection()
        super().build(input_shape)

    def call(self, x):
        """Apply the unit to ``x`` and return ``activation(main * gate + skip)``."""
        main = x
        for layer in self.block_layers:
            main = layer(main)

        # Per-channel gate computed from the main-path output.
        gate = main
        for layer in self.SE:
            gate = layer(gate)

        skip = x
        for layer in self.skip_layers:
            skip = layer(skip)

        return self.activation(main * gate + skip)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'strides': self.strides,
            'activation': keras.activations.serialize(self.activation),
        })
        return config

In [None]:
class SE_ResidualUnit_default(keras.layers.Layer):
    """Two-convolution residual unit with squeeze-and-excitation (SE) recalibration.

    Main path: (3, 4) conv (strided) -> BN -> activation -> (3, 4) conv -> BN,
    both convolutions with ``filters`` channels.

    The SE branch turns the main-path output into one sigmoid gate per channel
    (global average pool -> Dense(filters // 4, relu) -> Dense(filters, sigmoid))
    and the main-path output is multiplied by that gate before being added to
    the skip connection.

    The skip connection is the identity unless the main path changes the tensor
    shape, in which case it is a strided 1x1 conv followed by BN.

    Args:
        filters: Number of output channels of the unit.
        strides: Stride of the first convolution (and of the skip projection).
        activation: Activation identifier resolved via ``keras.activations.get``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, strides=1, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() and build() can refer to them.
        self.filters = filters
        self.strides = strides
        self.activation = keras.activations.get(activation)

        # GlobalAvgPool2D already returns a (batch, channels) tensor, so no
        # Flatten layer is needed in front of the Dense layers.
        self.SE = [keras.layers.GlobalAvgPool2D(),
                   keras.layers.Dense(filters // 4, activation='relu'),
                   keras.layers.Dense(filters, activation='sigmoid'),
                   keras.layers.Reshape([1, 1, filters])]

        # NOTE(review): the non-square (3, 4) kernel is kept as in the original;
        # confirm it is intentional.
        self.block_layers = [
            keras.layers.Conv2D(filters, kernel_size=(3, 4), strides=strides, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
            self.activation,
            keras.layers.Conv2D(filters, kernel_size=(3, 4), strides=1, padding='same', use_bias=False),
            keras.layers.BatchNormalization(),
        ]

        # A spatial down-sampling always needs a projection on the skip path.
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = self._make_projection()

    def _make_projection(self):
        """Return the 1x1 conv + BN layers that reshape the skip connection."""
        return [keras.layers.Conv2D(self.filters, kernel_size=(1, 1), strides=self.strides,
                                    padding='same', use_bias=False),
                keras.layers.BatchNormalization()]

    def build(self, input_shape):
        """Add a skip projection when the input channel count differs from ``filters``.

        Without it the final addition fails for ``strides == 1`` whenever the
        incoming tensor does not already have ``filters`` channels.
        Assumes channels-last inputs (the SE gate is reshaped to [1, 1, filters]).
        """
        in_channels = input_shape[-1]
        if not self.skip_layers and in_channels is not None and in_channels != self.filters:
            self.skip_layers = self._make_projection()
        super().build(input_shape)

    def call(self, x):
        """Apply the unit to ``x`` and return ``activation(main * gate + skip)``."""
        main = x
        for layer in self.block_layers:
            main = layer(main)

        # Per-channel gate computed from the main-path output.
        gate = main
        for layer in self.SE:
            gate = layer(gate)

        skip = x
        for layer in self.skip_layers:
            skip = layer(skip)

        return self.activation(main * gate + skip)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'strides': self.strides,
            'activation': keras.activations.serialize(self.activation),
        })
        return config

In [37]:
class SE_InceptionModule(keras.layers.Layer):
    """Four-branch inception module followed by squeeze-and-excitation gating.

    Branches (all applied to the same input, stride 1, 'same' padding):
        a: 1x1 conv (filters[0]) -> BN -> activation
        b: 1x1 conv (filters[1]) -> (3, 4) conv (filters[2]) -> BN -> activation
        c: 1x1 conv (filters[3]) -> (5, 6) conv (filters[4]) -> BN -> activation
        d: (3, 4) max pool -> 1x1 conv (filters[5]) -> BN -> activation

    The branch outputs are concatenated along axis 3. The SE branch then turns
    the concatenated tensor into one sigmoid gate per channel, and the module
    returns the concatenated tensor multiplied by that gate.

    Args:
        filters: Indexable sequence with at least six channel counts, used as
            listed above.
        activation: Activation identifier resolved via ``keras.activations.get``
            and applied at the end of every branch.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() can reproduce the constructor call.
        self.filters = filters
        # Honour the caller's choice; the original hard-coded 'relu' here and
        # silently ignored the `activation` argument.
        self.activation = keras.activations.get(activation)

        # Channel count of the concatenated branch outputs.
        SE_filters = filters[0] + filters[2] + filters[4] + filters[5]
        # GlobalAvgPool2D already returns a (batch, channels) tensor, so no
        # Flatten layer is needed in front of the Dense layers.
        self.SE = [keras.layers.GlobalAvgPool2D(),
                   keras.layers.Dense(SE_filters // 4, activation='relu'),
                   keras.layers.Dense(SE_filters, activation='sigmoid'),
                   keras.layers.Reshape([1, 1, SE_filters])]

        # use_bias must be the boolean False: the string 'false' used before is
        # truthy, so every convolution was created WITH a bias.
        # NOTE(review): the non-square (3, 4) / (5, 6) kernels and pool size are
        # kept as in the original; confirm they are intentional.
        self.conv_a    = keras.layers.Conv2D(filters[0], kernel_size=(1, 1), padding='same', use_bias=False, strides=1)
        self.bn_a      = keras.layers.BatchNormalization()

        self.conv_b1   = keras.layers.Conv2D(filters[1], kernel_size=(1, 1), padding='same', use_bias=False, strides=1)
        self.conv_b2   = keras.layers.Conv2D(filters[2], kernel_size=(3, 4), padding='same', use_bias=False, strides=1)
        self.bn_b      = keras.layers.BatchNormalization()

        self.conv_c1   = keras.layers.Conv2D(filters[3], kernel_size=(1, 1), padding='same', use_bias=False, strides=1)
        self.conv_c2   = keras.layers.Conv2D(filters[4], kernel_size=(5, 6), padding='same', use_bias=False, strides=1)
        self.bn_c      = keras.layers.BatchNormalization()

        self.maxpool_d = keras.layers.MaxPooling2D(pool_size=(3, 4), padding='same', strides=1)
        self.conv_d    = keras.layers.Conv2D(filters[5], kernel_size=(1, 1), padding='same', use_bias=False, strides=1)
        self.bn_d      = keras.layers.BatchNormalization()

    def call(self, x):
        """Run the four branches on ``x``, concatenate them and apply the SE gate."""
        branch_a = self.activation(self.bn_a(self.conv_a(x)))
        branch_b = self.activation(self.bn_b(self.conv_b2(self.conv_b1(x))))
        branch_c = self.activation(self.bn_c(self.conv_c2(self.conv_c1(x))))
        branch_d = self.activation(self.bn_d(self.conv_d(self.maxpool_d(x))))

        y = tf.concat([branch_a, branch_b, branch_c, branch_d], axis=3)

        # Per-channel gate computed from the concatenated features.
        gate = y
        for layer in self.SE:
            gate = layer(gate)

        return gate * y

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            'filters': list(self.filters),
            'activation': keras.activations.serialize(self.activation),
        })
        return config