In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.activations import *
from tensorflow.keras.regularizers import *
from tensorflow.keras import Input
from tensorflow.keras.initializers import LecunUniform
from tensorflow.keras.applications.resnet import ResNet50 , ResNet101

from typing import Union , Tuple , Dict , List

In [None]:
def AdaptAvgPooling(input_tensor: tf.float32,
                    output_size: int) -> tf.float32 :
    """Approximate adaptive average pooling with a fixed AveragePooling2D layer.

    The pooling window and stride are derived from the *static* spatial size of
    ``input_tensor`` so that the output spatial size is
    ``output_size`` x ``output_size``. With 'valid' padding the output length is
    ``(size - kernel) // stride + 1 == output_size`` because
    ``kernel = size - (output_size - 1) * stride``. When the size is not
    divisible by ``output_size``, the last window simply grows to cover the
    remainder.

    Args:
        input_tensor: 4-D tensor; assumes channels_last (N, H, W, C), the Keras
            default. H and W must be statically known.
        output_size: Desired spatial size of the pooled output on both axes.

    Returns:
        The pooled tensor.

    Raises:
        ValueError: if ``output_size`` < 1, if height/width are not statically
            known, or if either is smaller than ``output_size``.
    """
    if output_size < 1:
        raise ValueError(f"output_size must be >= 1, got {output_size}")
    height, width = input_tensor.shape[1], input_tensor.shape[2]
    if height is None or width is None:
        raise ValueError("AdaptAvgPooling needs statically known height and width")
    if height < output_size or width < output_size:
        raise ValueError(
            f"spatial size ({height}, {width}) is smaller than output_size={output_size}")

    # Keras pooling layers require plain Python ints for pool_size/strides;
    # tf tensors (as produced by tf.floor/tf.cast) are rejected.
    stride_h = height // output_size
    stride_w = width // output_size
    kernel_h = height - (output_size - 1) * stride_h
    kernel_w = width - (output_size - 1) * stride_w
    return AveragePooling2D(pool_size=(kernel_h, kernel_w),
                            strides=(stride_h, stride_w))(input_tensor)


def Channel_Conv(input_tensor: tf.float32,
                 filters: int,
                 kernel_size: (int, int),
                 strides: (int, int) = (1, 1),
                 padding: str = 'same',
                 activation: str = 'relu',
                 dilation: (int, int) = (1, 1),
                 num_split: int = 8) -> tf.float32 :
    """Grouped convolution: split channels, convolve each group, concatenate.

    The last axis of ``input_tensor`` is split into ``num_split`` equal groups.
    Each group goes through its own bias-free Conv2D producing
    ``filters // num_split`` channels, followed by BatchNormalization. The
    group outputs are concatenated back along the last axis, so the result has
    ``filters`` channels.

    Args:
        input_tensor: 4-D tensor; assumes channels_last (N, H, W, C) — TODO confirm.
        filters: Total number of output channels across all groups.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.
        activation: Activation applied inside each Conv2D (before BatchNorm).
        dilation: Dilation rate passed to each Conv2D.
        num_split: Number of channel groups.

    Returns:
        Tensor with ``filters`` channels.

    Raises:
        ValueError: if ``filters`` or the (static) input channel count is not
            divisible by ``num_split``.
    """
    in_channels = input_tensor.shape[-1]
    if in_channels is not None and in_channels % num_split != 0:
        raise ValueError(
            f"input channels ({in_channels}) must be divisible by num_split ({num_split})")
    if filters % num_split != 0:
        raise ValueError(
            f"filters ({filters}) must be divisible by num_split ({num_split})")

    num_filters = filters // num_split
    groups = tf.split(input_tensor, num_or_size_splits=num_split, axis=-1)
    group_outputs = []
    for group in groups:
        # Each group only sees its own slice of channels, which is what makes
        # this a grouped convolution.
        y = Conv2D(filters=num_filters,
                   kernel_size=kernel_size,
                   strides=strides,
                   padding=padding,
                   activation=activation,
                   dilation_rate=dilation,
                   use_bias=False)(group)
        group_outputs.append(BatchNormalization()(y))
    return tf.concat(group_outputs, axis=-1)


def SELayer(input_tensor: tf.float32,
            decay_rate: int = 16) -> tf.float32 :
    """Squeeze-and-Excitation block: rescale each channel by a learned gate.

    Squeeze: global average over the spatial axes gives one value per channel.
    Excitation: Dense(C // decay_rate, relu) -> Dense(C, sigmoid) yields a gate
    in (0, 1) per channel, which multiplies ``input_tensor``.

    Args:
        input_tensor: 4-D tensor; assumes channels_last (N, H, W, C) — TODO confirm.
        decay_rate: Reduction ratio of the bottleneck Dense layer.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    channel = input_tensor.shape[-1]
    # A global spatial average is the actual "squeeze". Average pooling with
    # pool_size=(1, 1) is an identity op and would gate each pixel separately.
    squeeze = GlobalAveragePooling2D()(input_tensor)
    squeeze = Reshape((1, 1, channel))(squeeze)
    # Keep at least one hidden unit when channel < decay_rate.
    hidden_units = max(channel // decay_rate, 1)
    excitation = Dense(hidden_units,
                       kernel_initializer='he_normal',
                       activation='relu')(squeeze)
    excitation = Dense(channel,
                       kernel_initializer='he_normal',
                       activation='sigmoid')(excitation)
    # (N, 1, 1, C) gates broadcast over the (N, H, W, C) input.
    return excitation * input_tensor



def SE_Residual(input_tensor: tf.float32,
                channels: List[int],
                decay_rate: int = 16) -> tf.float32 :
    """Bottleneck residual block with grouped 3x3 conv and SE recalibration.

    Residual branch: 1x1 conv (channels[0]) -> BN -> ReLU -> grouped 3x3 conv
    (channels[1], via Channel_Conv) -> 1x1 conv (channels[2]) -> BN -> SELayer.
    Shortcut: identity, or a 1x1 conv + BN projection when the input channel
    count differs from channels[-1]. The output is ReLU(shortcut + branch).

    Args:
        input_tensor: 4-D input tensor; assumes channels_last — TODO confirm.
        channels: Three widths [reduce, grouped, output].
        decay_rate: SE reduction ratio forwarded to SELayer.

    Returns:
        Tensor with channels[-1] channels and the input's spatial size
        (all convs here use stride 1).
    """
    x = Conv2D(filters=channels[0],
               kernel_size=(1, 1),
               strides=(1, 1),
               use_bias=False)(input_tensor)
    # Without BN + ReLU here, the 1x1 reduction and the following conv are two
    # linear maps in a row and collapse into a single linear conv.
    x = BatchNormalization()(x)
    x = relu(x)
    x = Channel_Conv(input_tensor=x,
                     filters=channels[1],
                     kernel_size=(3, 3),
                     strides=(1, 1))

    x = Conv2D(filters=channels[2],
               kernel_size=(1, 1),
               strides=(1, 1),
               use_bias=False)(x)
    x = BatchNormalization()(x)
    x = SELayer(input_tensor=x,
                decay_rate=decay_rate)

    shortcut = input_tensor
    if shortcut.shape[-1] != channels[-1]:
        # Projection shortcut to match channel counts; normalized like the branch.
        shortcut = Conv2D(filters=channels[-1],
                          kernel_size=(1, 1),
                          strides=(1, 1),
                          activation=None,
                          use_bias=False)(shortcut)
        shortcut = BatchNormalization()(shortcut)
    return relu(shortcut + x)


def SEResidual_MaxPool(num_repeat: int,
                       input_tensor: tf.float32,
                       channels: List[int],
                       decay_rate: int = 16) -> tf.float32:
    """One network stage: ``num_repeat`` SE_Residual blocks, then 3x3/2 max-pool.

    Args:
        num_repeat: Number of SE_Residual blocks stacked in this stage.
        input_tensor: 4-D stage input.
        channels: [reduce, grouped, output] widths passed to every block.
        decay_rate: SE reduction ratio forwarded to every block (default 16,
            matching SE_Residual's default).

    Returns:
        Output of the final block after max-pooling with stride 2 and 'same'
        padding.
    """
    x = input_tensor
    for _ in range(num_repeat):
        x = SE_Residual(input_tensor=x,
                        channels=channels,
                        decay_rate=decay_rate)
    return MaxPooling2D(pool_size=(3, 3),
                        strides=2,
                        padding='same')(x)
 

def SE_ResNet50(input_shape: Tuple[int, ...],
                num_classes: int):
    """Build an SE-ResNet-style classifier with a 3-4-6-3 stage layout.

    Stem: 7x7/2 conv (64) -> BN -> ReLU -> 3x3/2 max-pool. Then four stages of
    3, 4, 6 and 3 SE_Residual blocks. Each stage ends with a 3x3/2 max-pool
    (see SEResidual_MaxPool). Stage k uses widths [w, w, 2w] with w = 64 * 2**k.
    Head: global average pooling -> Dense(num_classes, softmax, no bias).

    NOTE(review): canonical ResNet-50 uses a 4x bottleneck expansion
    (256..2048 output channels) and downsamples in only three of the four
    stages. This model uses 2x expansion and pools after every stage — confirm
    this is intended.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
        num_classes: Number of output classes.

    Returns:
        A keras Model named 'SE_Resnet50'.
    """
    input_ = Input(shape=input_shape)
    x = Conv2D(filters=64,
               kernel_size=(7, 7),
               strides=2,
               padding='same')(input_)
    x = BatchNormalization()(x)
    x = relu(x)
    x = MaxPooling2D(pool_size=(3, 3),
                     strides=2,
                     padding='same')(x)
    n_channels = 64
    for num_blocks in [3, 4, 6, 3]:
        # Each stage uses its own block count (previously hard-coded to 3).
        x = SEResidual_MaxPool(num_repeat=num_blocks,
                               input_tensor=x,
                               channels=[n_channels, n_channels, 2*n_channels])
        n_channels = 2*n_channels

    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    use_bias=False)(x)
    return Model(inputs=input_, outputs=outputs, name='SE_Resnet50')