In [6]:
import tensorflow as tf
from tensorflow.keras import layers, models

In [None]:
def se_block(x, r=16):
    """Apply a squeeze-and-excitation (SE) channel-attention block to ``x``.

    The feature map is squeezed to one value per channel with global average
    pooling, passed through a two-layer bottleneck MLP (ReLU, then sigmoid),
    and the resulting per-channel weights rescale ``x`` via element-wise
    multiplication.

    Args:
        x: Keras tensor with channels on the last axis. The ``Reshape`` to
            ``(1, 1, c)`` below assumes a rank-4 (batch, H, W, C) input.
        r: Channel reduction ratio for the hidden Dense layer. The hidden
            layer has ``c // r`` units, clamped to at least 1.

    Returns:
        Tensor with the same shape as ``x``, rescaled channel-wise.

    Raises:
        ValueError: if ``r`` is smaller than 1, or if the channel dimension
            of ``x`` is not statically known.
    """
    if r < 1:
        raise ValueError(f"reduction ratio r must be >= 1, got {r}")
    channels = x.shape[-1]
    if channels is None:
        raise ValueError("se_block requires a statically known channel dimension")
    c = int(channels)
    # When c < r, plain c // r is 0. Dense(units=0) is rejected by recent Keras
    # versions and, where accepted, makes the gate independent of the input.
    hidden = max(1, c // r)

    s = layers.GlobalAveragePooling2D()(x)  # squeeze: one value per channel
    print('GlobalAveragePooling2D.shape:', s.shape)
    s = layers.Dense(units=hidden, activation='relu')(s)
    print('Dense1.shape:', s.shape)
    s = layers.Dense(units=c, activation='sigmoid')(s)  # per-channel gates in (0, 1)
    print('Dense2.shape:', s.shape)
    s = layers.Reshape((1, 1, c))(s)  # broadcastable over the spatial axes of x
    print('Reshape.shape:', s.shape)
    return layers.Multiply()([x, s])
                  

def bottleneck_se(x, stride=2, filters=64, expansion=4, r=16):
    """Build a bottleneck residual block with squeeze-and-excitation attention.

    Main path:
        1x1 conv (``filters``) -> BN -> ReLU
        -> 3x3 conv (``filters``, ``stride``, 'same' padding) -> BN -> ReLU
        -> 1x1 conv (``filters * expansion``) -> BN -> SE recalibration.

    The main path is added to the shortcut and passed through a final ReLU.
    The shortcut is a strided 1x1 conv + BN projection whenever the identity
    could not be added directly: the stride is not 1, or the input channel
    count differs from ``filters * expansion``.

    Args:
        x: Input Keras tensor. Channels-last (batch, H, W, C) is assumed,
            matching the layer defaults used below.
        stride: Spatial stride of the 3x3 conv and of the projection shortcut.
        filters: Width of the two inner convolutions (default 64).
        expansion: Output width multiplier. The block outputs
            ``filters * expansion`` channels (default 4, i.e. 256 channels).
        r: SE reduction ratio forwarded to ``se_block``.

    Returns:
        Tensor of shape (batch, ceil(H / stride), ceil(W / stride),
        filters * expansion).
    """
    out_channels = filters * expansion

    shortcut = x
    # Project the identity whenever Add() could not combine it with the main
    # path as is. A stride changes H/W, and a channel mismatch breaks Add() too
    # (the original only checked the stride).
    if stride != 1 or x.shape[-1] != out_channels:
        shortcut = layers.Conv2D(filters=out_channels, kernel_size=1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    print('shortcut.shape:', shortcut.shape)

    # 1x1 reduce
    x = layers.Conv2D(filters=filters, kernel_size=1, use_bias=False)(x)
    print('x1.shape:', x.shape)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 3x3 spatial conv; carries the block's stride
    x = layers.Conv2D(filters=filters, kernel_size=3, strides=stride, padding='same', use_bias=False)(x)
    print('x2.shape:', x.shape)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 1x1 expand; no ReLU before the residual addition
    x = layers.Conv2D(filters=out_channels, kernel_size=1, use_bias=False)(x)
    print('x3.shape:', x.shape)
    x = layers.BatchNormalization()(x)

    x = se_block(x, r=r)
    print('x4.shape:', x.shape)

    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x

# Demo: wrap one strided SE bottleneck block in a functional Model.
# A 112x112x64 input is downsampled by stride 2 to a 56x56x256 output.
inputs = layers.Input(shape=(112, 112, 64))
outputs = bottleneck_se(inputs, stride=2)
model = models.Model(inputs=inputs, outputs=outputs)
model.summary()

shortcut.shape: (None, 56, 56, 256)
x1.shape: (None, 112, 112, 64)
x2.shape: (None, 56, 56, 64)
x3.shape: (None, 56, 56, 256)
GlobalAveragePooling2D.shape: (None, 256)
Dense1.shape: (None, 16)
Dense2.shape: (None, 256)
Reshape.shape: (None, 1, 1, 256)
x4.shape: (None, 56, 56, 256)
